Merge branch 'master' of https://github.com/matrix-org/dendrite into add-health-endpoint

Commit 18ce352f29 by Till Faelligen, 2020-10-17 15:34:39 +02:00
65 changed files with 1189 additions and 462 deletions


@@ -30,6 +30,7 @@ import (
 	"github.com/matrix-org/dendrite/appservice/workers"
 	"github.com/matrix-org/dendrite/internal/config"
 	"github.com/matrix-org/dendrite/internal/setup"
+	"github.com/matrix-org/dendrite/internal/setup/kafka"
 	roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/sirupsen/logrus"
@@ -47,6 +48,8 @@ func NewInternalAPI(
 	userAPI userapi.UserInternalAPI,
 	rsAPI roomserverAPI.RoomserverInternalAPI,
 ) appserviceAPI.AppServiceQueryAPI {
+	consumer, _ := kafka.SetupConsumerProducer(&base.Cfg.Global.Kafka)
 	// Create a connection to the appservice postgres DB
 	appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database)
 	if err != nil {
@@ -86,7 +89,7 @@ func NewInternalAPI(
 	// We can't add ASes at runtime so this is safe to do.
 	if len(workerStates) > 0 {
 		consumer := consumers.NewOutputRoomEventConsumer(
-			base.Cfg, base.KafkaConsumer, appserviceDB,
+			base.Cfg, consumer, appserviceDB,
 			rsAPI, workerStates,
 		)
 		if err := consumer.Start(); err != nil {


@@ -38,21 +38,38 @@ go run github.com/matrix-org/dendrite/cmd/generate-keys \
   --tls-key=server.key
 ```
-## Starting Dendrite
-Once in place, start the dependencies:
+## Starting Dendrite as a monolith deployment
+
+Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
+folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Additionally,
+make the following changes to the configuration:
+
+- Enable Naffka: `use_naffka: true`
+
+Once in place, start the PostgreSQL dependency:
 ```
-docker-compose -f docker-compose.deps.yml up
+docker-compose -f docker-compose.deps.yml up postgres
 ```
-Wait a few seconds for Kafka and Postgres to finish starting up, and then start a monolith:
+Wait a few seconds for PostgreSQL to finish starting up, and then start a monolith:
 ```
 docker-compose -f docker-compose.monolith.yml up
 ```
-... or start the polylith components:
+## Starting Dendrite as a polylith deployment
+
+Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
+folder in the [Dendrite repository](https://github.com/matrix-org/dendrite).
+
+Once in place, start all the dependencies:
+```
+docker-compose -f docker-compose.deps.yml up
+```
+
+Wait a few seconds for PostgreSQL and Kafka to finish starting up, and then start a polylith:
 ```
 docker-compose -f docker-compose.polylith.yml up


@@ -76,7 +76,7 @@ global:
   # Naffka database options. Not required when using Kafka.
   naffka_database:
-    connection_string: file:naffka.db
+    connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable
     max_open_conns: 100
    max_idle_conns: 2
     conn_max_lifetime: -1


@@ -6,6 +6,9 @@ services:
     restart: always
     volumes:
       - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh
+      # To persist your PostgreSQL databases outside of the Docker image, to
+      # prevent data loss, you will need to add something like this:
+      # - ./path/to/persistent/storage:/var/lib/postgresql/data
     environment:
       POSTGRES_PASSWORD: itsasecret
       POSTGRES_USER: dendrite
@@ -26,6 +29,8 @@ services:
       KAFKA_ADVERTISED_HOST_NAME: "kafka"
       KAFKA_DELETE_TOPIC_ENABLE: "true"
       KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181"
+    ports:
+      - 9092:9092
     depends_on:
       - zookeeper
     networks:

build/docker/postgres/create_db.sh Normal file → Executable file

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/bin/sh
 for db in account device mediaapi syncapi roomserver signingkeyserver keyserver federationsender appservice e2ekey naffka; do
     createdb -U dendrite -O dendrite dendrite_$db


@@ -112,7 +112,7 @@ func (m *DendriteMonolith) Start() {
 	serverKeyAPI := &signing.YggdrasilKeys{}
 	keyRing := serverKeyAPI.KeyRing()
-	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
+	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation)
 	userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
 	keyAPI.SetUserAPI(userAPI)
@@ -146,13 +146,11 @@ func (m *DendriteMonolith) Start() {
 	rsAPI.SetFederationSenderAPI(fsAPI)
 	monolith := setup.Monolith{
 		Config:        base.Cfg,
 		AccountDB:     accountDB,
 		Client:        ygg.CreateClient(base),
 		FedClient:     federation,
 		KeyRing:       keyRing,
-		KafkaConsumer: base.KafkaConsumer,
-		KafkaProducer: base.KafkaProducer,
 		AppserviceAPI: asAPI,
 		EDUInternalAPI: eduInputAPI,


@@ -15,7 +15,6 @@
 package clientapi
 import (
-	"github.com/Shopify/sarama"
 	"github.com/gorilla/mux"
 	appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
 	"github.com/matrix-org/dendrite/clientapi/api"
@@ -24,6 +23,7 @@ import (
 	eduServerAPI "github.com/matrix-org/dendrite/eduserver/api"
 	federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
 	"github.com/matrix-org/dendrite/internal/config"
+	"github.com/matrix-org/dendrite/internal/setup/kafka"
 	"github.com/matrix-org/dendrite/internal/transactions"
 	keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
 	roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@@ -36,7 +36,6 @@ import (
 func AddPublicRoutes(
 	router *mux.Router,
 	cfg *config.ClientAPI,
-	producer sarama.SyncProducer,
 	accountsDB accounts.Database,
 	federation *gomatrixserverlib.FederationClient,
 	rsAPI roomserverAPI.RoomserverInternalAPI,
@@ -48,6 +47,8 @@ func AddPublicRoutes(
 	keyAPI keyserverAPI.KeyInternalAPI,
 	extRoomsProvider api.ExtraPublicRoomsProvider,
 ) {
+	_, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
 	syncProducer := &producers.SyncAPIProducer{
 		Producer: producer,
 		Topic:    cfg.Matrix.Kafka.TopicFor(config.TopicOutputClientData),


@@ -339,12 +339,21 @@ func createRoom(
 			util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed")
 			return jsonerror.InternalServerError()
 		}
-	}
-	// send events to the room server
-	if err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil); err != nil {
-		util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
-		return jsonerror.InternalServerError()
+		accumulated := gomatrixserverlib.UnwrapEventHeaders(builtEvents)
+		if err = roomserverAPI.SendEventWithState(
+			req.Context(),
+			rsAPI,
+			&gomatrixserverlib.RespState{
+				StateEvents: accumulated,
+				AuthEvents:  accumulated,
+			},
+			ev.Headered(roomVersion),
+			nil,
+		); err != nil {
+			util.GetLogger(req.Context()).WithError(err).Error("SendEventWithState failed")
+			return jsonerror.InternalServerError()
+		}
 	}
 	// TODO(#269): Reserve room alias while we create the room. This stops us


@@ -13,6 +13,7 @@ import (
 type rateLimits struct {
 	limits           map[string]chan struct{}
 	limitsMutex      sync.RWMutex
+	cleanMutex       sync.RWMutex
 	enabled          bool
 	requestThreshold int64
 	cooloffDuration  time.Duration
@@ -38,6 +39,7 @@ func (l *rateLimits) clean() {
 		// empty. If they are then we will close and delete them,
 		// freeing up memory.
 		time.Sleep(time.Second * 30)
+		l.cleanMutex.Lock()
 		l.limitsMutex.Lock()
 		for k, c := range l.limits {
 			if len(c) == 0 {
@@ -46,6 +48,7 @@ func (l *rateLimits) clean() {
 			}
 		}
 		l.limitsMutex.Unlock()
+		l.cleanMutex.Unlock()
 	}
 }
@@ -55,12 +58,12 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
 		return nil
 	}
-	// Lock the map long enough to check for rate limiting. We hold it
-	// for longer here than we really need to but it makes sure that we
-	// also don't conflict with the cleaner goroutine which might clean
-	// up a channel after we have retrieved it otherwise.
-	l.limitsMutex.RLock()
-	defer l.limitsMutex.RUnlock()
+	// Take a read lock out on the cleaner mutex. The cleaner expects to
+	// be able to take a write lock, which isn't possible while there are
+	// readers, so this has the effect of blocking the cleaner goroutine
+	// from doing its work until there are no requests in flight.
+	l.cleanMutex.RLock()
+	defer l.cleanMutex.RUnlock()
 	// First of all, work out if X-Forwarded-For was sent to us. If not
 	// then we'll just use the IP address of the caller.
@@ -69,12 +72,19 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
 		caller = forwardedFor
 	}
-	// Look up the caller's channel, if they have one. If they don't then
-	// let's create one.
+	// Look up the caller's channel, if they have one.
+	l.limitsMutex.RLock()
 	rateLimit, ok := l.limits[caller]
+	l.limitsMutex.RUnlock()
+	// If the caller doesn't have a channel, create one and write it
+	// back to the map.
 	if !ok {
-		l.limits[caller] = make(chan struct{}, l.requestThreshold)
-		rateLimit = l.limits[caller]
+		rateLimit = make(chan struct{}, l.requestThreshold)
+		l.limitsMutex.Lock()
+		l.limits[caller] = rateLimit
+		l.limitsMutex.Unlock()
 	}
 	// Check if the user has got free resource slots for this request.
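The rewritten comments describe a two-mutex scheme: every request holds a read lock on a dedicated cleaner mutex, the cleaner goroutine takes the corresponding write lock before sweeping, and a separate mutex guards the map itself, so a caller's channel can never be closed while a request is still using it. Below is a minimal standalone sketch of that pattern; the type and the `allow` helper are hypothetical names for illustration, not the Dendrite code.

```
package ratelimit

import (
	"sync"
	"time"
)

type limiter struct {
	cleanMutex  sync.RWMutex             // requests hold the read lock; the cleaner needs the write lock
	limitsMutex sync.RWMutex             // guards the limits map itself
	limits      map[string]chan struct{} // one buffered channel of in-flight requests per caller
	threshold   int
}

// clean periodically deletes empty per-caller channels. Acquiring the write
// lock on cleanMutex blocks until no request holds the read lock, so a
// channel is never closed while a request is still using it.
func (l *limiter) clean(interval time.Duration) {
	for {
		time.Sleep(interval)
		l.cleanMutex.Lock()
		l.limitsMutex.Lock()
		for caller, ch := range l.limits {
			if len(ch) == 0 {
				close(ch)
				delete(l.limits, caller)
			}
		}
		l.limitsMutex.Unlock()
		l.cleanMutex.Unlock()
	}
}

// allow reports whether the caller still has a free slot, creating the
// caller's channel on first use.
func (l *limiter) allow(caller string) bool {
	l.cleanMutex.RLock()
	defer l.cleanMutex.RUnlock()

	l.limitsMutex.RLock()
	ch, ok := l.limits[caller]
	l.limitsMutex.RUnlock()
	if !ok {
		ch = make(chan struct{}, l.threshold)
		l.limitsMutex.Lock()
		l.limits[caller] = ch
		l.limitsMutex.Unlock()
	}

	select {
	case ch <- struct{}{}:
		// Slot reserved; a real implementation would release it when the
		// request completes or after the cooloff period.
		return true
	default:
		return false
	}
}
```

Splitting the two locks keeps each map access short while still expressing the invariant that cleanup only runs when no requests are in flight.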


@@ -37,7 +37,7 @@ func main() {
 	keyAPI := base.KeyServerHTTPClient()
 	clientapi.AddPublicRoutes(
-		base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation,
+		base.PublicClientAPIMux, &base.Cfg.ClientAPI, accountDB, federation,
 		rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, keyAPI, nil,
 	)


@@ -89,7 +89,7 @@ func createClient(
 		"matrix",
 		p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
 	)
-	return gomatrixserverlib.NewClientWithTransport(true, tr)
+	return gomatrixserverlib.NewClientWithTransport(tr)
 }
 func main() {
@@ -139,7 +139,7 @@ func main() {
 	accountDB := base.Base.CreateAccountsDB()
 	federation := createFederationClient(base)
-	keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer)
+	keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation)
 	userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
 	keyAPI.SetUserAPI(userAPI)
@@ -169,13 +169,11 @@ func main() {
 	}
 	monolith := setup.Monolith{
 		Config:        base.Base.Cfg,
 		AccountDB:     accountDB,
 		Client:        createClient(base),
 		FedClient:     federation,
 		KeyRing:       keyRing,
-		KafkaConsumer: base.Base.KafkaConsumer,
-		KafkaProducer: base.Base.KafkaProducer,
 		AppserviceAPI: asAPI,
 		EDUInternalAPI: eduInputAPI,


@@ -96,7 +96,7 @@ func main() {
 	serverKeyAPI := &signing.YggdrasilKeys{}
 	keyRing := serverKeyAPI.KeyRing()
-	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
+	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation)
 	userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
 	keyAPI.SetUserAPI(userAPI)
@@ -129,13 +129,11 @@ func main() {
 	rsComponent.SetFederationSenderAPI(fsAPI)
 	monolith := setup.Monolith{
 		Config:        base.Cfg,
 		AccountDB:     accountDB,
 		Client:        ygg.CreateClient(base),
 		FedClient:     federation,
 		KeyRing:       keyRing,
-		KafkaConsumer: base.KafkaConsumer,
-		KafkaProducer: base.KafkaProducer,
 		AppserviceAPI: asAPI,
 		EDUInternalAPI: eduInputAPI,


@@ -33,7 +33,7 @@ func (n *Node) CreateClient(
 		},
 		},
 	)
-	return gomatrixserverlib.NewClientWithTransport(true, tr)
+	return gomatrixserverlib.NewClientWithTransport(tr)
 }
 func (n *Node) CreateFederationClient(


@@ -24,7 +24,7 @@ func main() {
 	base := setup.NewBaseDendrite(cfg, "KeyServer", true)
 	defer base.Close() // nolint: errcheck
-	intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.KafkaProducer)
+	intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient())
 	intAPI.SetUserAPI(base.UserAPIClient())
 	keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)


@@ -108,7 +108,7 @@ func main() {
 	// This is different to rsAPI which can be the http client which doesn't need this dependency
 	rsImpl.SetFederationSenderAPI(fsAPI)
-	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, fsAPI, base.KafkaProducer)
+	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, fsAPI)
 	userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
 	keyAPI.SetUserAPI(userAPI)
@@ -127,13 +127,11 @@ func main() {
 	}
 	monolith := setup.Monolith{
 		Config:        base.Cfg,
 		AccountDB:     accountDB,
 		Client:        base.CreateClient(),
 		FedClient:     federation,
 		KeyRing:       keyRing,
-		KafkaConsumer: base.KafkaConsumer,
-		KafkaProducer: base.KafkaProducer,
 		AppserviceAPI: asAPI,
 		EDUInternalAPI: eduInputAPI,


@@ -21,6 +21,7 @@ import (
 func main() {
 	cfg := setup.ParseFlags(false)
 	base := setup.NewBaseDendrite(cfg, "SyncAPI", true)
 	defer base.Close() // nolint: errcheck
@@ -30,7 +31,7 @@ func main() {
 	rsAPI := base.RoomserverHTTPClient()
 	syncapi.AddPublicRoutes(
-		base.PublicClientAPIMux, base.KafkaConsumer, userAPI, rsAPI,
+		base.PublicClientAPIMux, userAPI, rsAPI,
 		base.KeyServerHTTPClient(),
 		federation, &cfg.SyncAPI,
 	)


@@ -141,14 +141,14 @@ func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLoc
 	fed := gomatrixserverlib.NewFederationClient(
 		cfg.Global.ServerName, cfg.Global.KeyID, cfg.Global.PrivateKey, true,
 	)
-	fed.Client = *gomatrixserverlib.NewClientWithTransport(true, tr)
+	fed.Client = *gomatrixserverlib.NewClientWithTransport(tr)
 	return fed
 }
 func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client {
 	tr := go_http_js_libp2p.NewP2pTransport(node)
-	return gomatrixserverlib.NewClientWithTransport(true, tr)
+	return gomatrixserverlib.NewClientWithTransport(tr)
 }
 func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) {
@@ -190,7 +190,7 @@ func main() {
 	accountDB := base.CreateAccountsDB()
 	federation := createFederationClient(cfg, node)
-	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
+	keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation)
 	userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
 	keyAPI.SetUserAPI(userAPI)
@@ -212,13 +212,11 @@ func main() {
 	p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
 	monolith := setup.Monolith{
 		Config:        base.Cfg,
 		AccountDB:     accountDB,
 		Client:        createClient(node),
 		FedClient:     federation,
 		KeyRing:       &keyRing,
-		KafkaConsumer: base.KafkaConsumer,
-		KafkaProducer: base.KafkaProducer,
 		AppserviceAPI: asQuery,
 		EDUInternalAPI: eduInputAPI,


@@ -8,19 +8,38 @@ import (
 	"log"
 	"os"
-	// Example complex Go migration import:
-	// _ "github.com/matrix-org/dendrite/serverkeyapi/storage/postgres/deltas"
+	pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
+	slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
+	pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
+	sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
 	"github.com/pressly/goose"
 	_ "github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
 )
-var (
-	flags = flag.NewFlagSet("goose", flag.ExitOnError)
-	dir   = flags.String("dir", ".", "directory with migration files")
+const (
+	AppService       = "appservice"
+	FederationSender = "federationsender"
+	KeyServer        = "keyserver"
+	MediaAPI         = "mediaapi"
+	RoomServer       = "roomserver"
+	SigningKeyServer = "signingkeyserver"
+	SyncAPI          = "syncapi"
+	UserAPIAccounts  = "userapi_accounts"
+	UserAPIDevices   = "userapi_devices"
 )
+var (
+	dir       = flags.String("dir", "", "directory with migration files")
+	flags     = flag.NewFlagSet("goose", flag.ExitOnError)
+	component = flags.String("component", "", "dendrite component name")
+	knownDBs  = []string{
+		AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
+	}
+)
+// nolint: gocyclo
 func main() {
 	err := flags.Parse(os.Args[1:])
 	if err != nil {
@@ -37,19 +56,20 @@ Drivers:
     sqlite3
 Examples:
-    goose -d roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db status
-    goose -d roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up
-    goose -d roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" status
+    goose -component roomserver sqlite3 ./roomserver.db status
+    goose -component roomserver sqlite3 ./roomserver.db up
+    goose -component roomserver postgres "user=dendrite dbname=dendrite sslmode=disable" status
 Options:
+    -component string
+        Dendrite component name e.g roomserver, signingkeyserver, clientapi, syncapi
-    -dir string
-        directory with migration files (default ".")
     -table string
         migrations table name (default "goose_db_version")
     -h print help
     -v enable verbose mode
+    -dir string
+        directory with migration files, only relevant when creating new migrations.
     -version
         print version
@@ -74,6 +94,25 @@ Commands:
 		fmt.Println("engine must be one of 'sqlite3' or 'postgres'")
 		return
 	}
+	knownComponent := false
+	for _, c := range knownDBs {
+		if c == *component {
+			knownComponent = true
+			break
+		}
+	}
+	if !knownComponent {
+		fmt.Printf("component must be one of %v\n", knownDBs)
+		return
+	}
+	if engine == "sqlite3" {
+		loadSQLiteDeltas(*component)
+	} else {
+		loadPostgresDeltas(*component)
+	}
 	dbstring, command := args[1], args[2]
 	db, err := goose.OpenDBWithDriver(engine, dbstring)
@@ -92,7 +131,30 @@ Commands:
 		arguments = append(arguments, args[3:]...)
 	}
-	if err := goose.Run(command, db, *dir, arguments...); err != nil {
+	// goose demands a directory even though we don't use it for upgrades
+	d := *dir
+	if d == "" {
+		d = os.TempDir()
+	}
+	if err := goose.Run(command, db, d, arguments...); err != nil {
 		log.Fatalf("goose %v: %v", command, err)
 	}
 }
+func loadSQLiteDeltas(component string) {
+	switch component {
+	case UserAPIAccounts:
+		slaccounts.LoadFromGoose()
+	case UserAPIDevices:
+		sldevices.LoadFromGoose()
+	}
+}
+func loadPostgresDeltas(component string) {
+	switch component {
+	case UserAPIAccounts:
+		pgaccounts.LoadFromGoose()
+	case UserAPIDevices:
+		pgdevices.LoadFromGoose()
+	}
+}
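The calls to `loadSQLiteDeltas`/`loadPostgresDeltas` have to run before `goose.Run` so that the chosen component's Go migrations are registered with goose. The deltas packages themselves are not part of this diff; the sketch below shows what a `LoadFromGoose` function typically looks like with pressly/goose, under the assumption that it simply registers up/down functions. The table and migration here are placeholders, not Dendrite's real schema, and Go migration files are conventionally named with a numeric version prefix so goose can order them.

```
package deltas

import (
	"database/sql"

	"github.com/pressly/goose"
)

// LoadFromGoose registers this package's Go migrations with goose. It must be
// called before goose.Run, which is why the goose tool loads the selected
// component's deltas first.
func LoadFromGoose() {
	goose.AddMigration(upPlaceholder, downPlaceholder)
}

// upPlaceholder stands in for a real schema change.
func upPlaceholder(tx *sql.Tx) error {
	_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS example_placeholder (id TEXT PRIMARY KEY);`)
	return err
}

// downPlaceholder reverses the placeholder change.
func downPlaceholder(tx *sql.Tx) error {
	_, err := tx.Exec(`DROP TABLE IF EXISTS example_placeholder;`)
	return err
}
```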


@@ -0,0 +1,17 @@
[Unit]
Description=Dendrite (Matrix Homeserver)
After=syslog.target
After=network.target
After=postgresql.service
[Service]
RestartSec=2s
Type=simple
User=dendrite
Group=dendrite
WorkingDirectory=/opt/dendrite/
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
Restart=always
[Install]
WantedBy=multi-user.target


@@ -24,6 +24,7 @@ import (
 	"github.com/matrix-org/dendrite/eduserver/inthttp"
 	"github.com/matrix-org/dendrite/internal/config"
 	"github.com/matrix-org/dendrite/internal/setup"
+	"github.com/matrix-org/dendrite/internal/setup/kafka"
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 )
@@ -41,10 +42,13 @@ func NewInternalAPI(
 	userAPI userapi.UserInternalAPI,
 ) api.EDUServerInputAPI {
 	cfg := &base.Cfg.EDUServer
+	_, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
 	return &input.EDUServerInputAPI{
 		Cache:   eduCache,
 		UserAPI: userAPI,
-		Producer: base.KafkaProducer,
+		Producer: producer,
 		OutputTypingEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)),
 		OutputSendToDeviceEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)),
 		ServerName: cfg.Matrix.ServerName,
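Several hunks in this commit repeat one pattern: components stop taking `base.KafkaConsumer`/`base.KafkaProducer` from `BaseDendrite` and instead build their own pair from the Kafka section of their config, as in `_, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)` above. The following is only a rough sketch of what such a helper can look like when backed by sarama; the `Config` type and the error handling are assumptions for illustration, and the Naffka fallback that the dendrite.yaml hunk earlier configures is presumably handled by the real helper but omitted here.

```
package kafka

import (
	"log"

	"github.com/Shopify/sarama"
)

// Config stands in for the Kafka section of the Dendrite config; only the
// broker addresses are needed for this sketch.
type Config struct {
	Addresses []string
}

// SetupConsumerProducer dials the configured brokers and returns a consumer
// and a producer for the calling component to use.
func SetupConsumerProducer(cfg *Config) (sarama.Consumer, sarama.SyncProducer) {
	consumer, err := sarama.NewConsumer(cfg.Addresses, nil)
	if err != nil {
		log.Fatalf("failed to create Kafka consumer: %v", err)
	}
	producerCfg := sarama.NewConfig()
	producerCfg.Producer.Return.Successes = true // required by SyncProducer
	producer, err := sarama.NewSyncProducer(cfg.Addresses, producerCfg)
	if err != nil {
		log.Fatalf("failed to create Kafka producer: %v", err)
	}
	return consumer, producer
}
```

Callers that only need one half of the pair can discard the other, as the EDU server does above with the consumer.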


@@ -16,6 +16,7 @@ package routing
 import (
 	"context"
+	"database/sql"
 	"encoding/json"
 	"fmt"
 	"net/http"
@@ -182,7 +183,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
 	// Process the events.
 	for _, e := range pdus {
-		if err := t.processEvent(ctx, e.Unwrap(), true); err != nil {
+		if err := t.processEvent(ctx, e.Unwrap()); err != nil {
 			// If the error is due to the event itself being bad then we skip
 			// it and move onto the next event. We report an error so that the
 			// sender knows that we have skipped processing it.
@@ -234,17 +235,10 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
 // we should stop processing the transaction, and returns false if it
 // is just some less serious error about a specific event.
 func isProcessingErrorFatal(err error) bool {
-	switch err.(type) {
-	case roomNotFoundError:
-	case *gomatrixserverlib.NotAllowed:
-	case missingPrevEventsError:
-	default:
-		switch err {
-		case context.Canceled:
-		case context.DeadlineExceeded:
-		default:
-			return true
-		}
+	switch err {
+	case sql.ErrConnDone:
+	case sql.ErrTxDone:
+		return true
 	}
 	return false
 }
@@ -252,9 +246,6 @@ func isProcessingErrorFatal(err error) bool {
 type roomNotFoundError struct {
 	roomID string
 }
-type unmarshalError struct {
-	err error
-}
 type verifySigError struct {
 	eventID string
 	err     error
@@ -265,7 +256,6 @@ type missingPrevEventsError struct {
 }
 func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) }
-func (e unmarshalError) Error() string { return fmt.Sprintf("unable to parse event: %s", e.err) }
 func (e verifySigError) Error() string {
 	return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err)
 }
@@ -299,6 +289,15 @@ func (t *txnReq) processEDUs(ctx context.Context) {
 				util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event")
 				continue
 			}
+			_, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID)
+			if err != nil {
+				util.GetLogger(ctx).WithError(err).Error("Failed to split domain from typing event sender")
+				continue
+			}
+			if domain != t.Origin {
+				util.GetLogger(ctx).Warnf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin)
+				continue
+			}
 			if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
 				util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to edu server")
 			}
@@ -344,11 +343,28 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
 	}
 }
-func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error {
+func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
+	servers := []gomatrixserverlib.ServerName{t.Origin}
+	serverReq := &api.QueryServerJoinedToRoomRequest{
+		RoomID: roomID,
+	}
+	serverRes := &api.QueryServerJoinedToRoomResponse{}
+	if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
+		servers = append(servers, serverRes.ServerNames...)
+		util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID)
+	}
+	return servers
+}
+func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error {
 	logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
 	// Work out if the roomserver knows everything it needs to know to auth
-	// the event.
+	// the event. This includes the prev_events and auth_events.
+	// NOTE! This is going to include prev_events that have an empty state
+	// snapshot. This is because we will need to re-request the event, and
+	// it's /state_ids, in order for it to exist in the roomserver correctly
+	// before the roomserver tries to work out
 	stateReq := api.QueryMissingAuthPrevEventsRequest{
 		RoomID:       e.RoomID(),
 		AuthEventIDs: e.AuthEventIDs(),
@@ -356,7 +372,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
 	}
 	var stateResp api.QueryMissingAuthPrevEventsResponse
 	if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil {
-		return err
+		return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err)
 	}

 	if !stateResp.RoomExists {
@@ -371,52 +387,14 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
 	if len(stateResp.MissingAuthEventIDs) > 0 {
 		logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
-		servers := []gomatrixserverlib.ServerName{t.Origin}
-		serverReq := &api.QueryServerJoinedToRoomRequest{
-			RoomID: e.RoomID(),
-		}
-		serverRes := &api.QueryServerJoinedToRoomResponse{}
-		if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
-			servers = append(servers, serverRes.ServerNames...)
-			logger.Infof("Found %d server(s) to query for missing events", len(servers))
-		}
-	getAuthEvent:
-		for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
-			for _, server := range servers {
-				logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
-				tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
-				if err != nil {
-					continue // try the next server
-				}
-				ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
-				if err != nil {
-					logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID)
-					continue // try the next server
-				}
-				if err = api.SendInputRoomEvents(
-					context.Background(),
-					t.rsAPI,
-					[]api.InputRoomEvent{
-						{
-							Kind:         api.KindOutlier,
-							Event:        ev.Headered(stateResp.RoomVersion),
-							AuthEventIDs: ev.AuthEventIDs(),
-							SendAsServer: api.DoNotSendToOtherServers,
-						},
-					},
-				); err != nil {
-					logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID)
-					continue getAuthEvent // move onto the next event
-				}
-			}
-		}
+		if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
+			return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
+		}
 	}
 	if len(stateResp.MissingPrevEventIDs) > 0 {
 		logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs))
-		return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn)
+		return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion)
 	}
 	// pass the event to the roomserver which will do auth checks
@@ -433,6 +411,60 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
 	)
 }
+func (t *txnReq) retrieveMissingAuthEvents(
+	ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse,
+) error {
+	logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
+	missingAuthEvents := make(map[string]struct{})
+	for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
+		missingAuthEvents[missingAuthEventID] = struct{}{}
+	}
+	servers := t.getServers(ctx, e.RoomID())
+	if len(servers) > 5 {
+		servers = servers[:5]
+	}
+withNextEvent:
+	for missingAuthEventID := range missingAuthEvents {
+	withNextServer:
+		for _, server := range servers {
+			logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
+			tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
+			if err != nil {
+				logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID)
+				continue withNextServer
+			}
+			ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
+			if err != nil {
+				logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
+				continue withNextServer
+			}
+			if err = api.SendInputRoomEvents(
+				context.Background(),
+				t.rsAPI,
+				[]api.InputRoomEvent{
+					{
+						Kind:         api.KindOutlier,
+						Event:        ev.Headered(stateResp.RoomVersion),
+						AuthEventIDs: ev.AuthEventIDs(),
+						SendAsServer: api.DoNotSendToOtherServers,
+					},
+				},
+			); err != nil {
+				return fmt.Errorf("api.SendEvents: %w", err)
+			}
+			delete(missingAuthEvents, missingAuthEventID)
+			continue withNextEvent
+		}
+	}
+	if missing := len(missingAuthEvents); missing > 0 {
+		return fmt.Errorf("Event refers to %d auth_events which we failed to fetch", missing)
+	}
+	return nil
+}
 func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
 	authUsingState := gomatrixserverlib.NewAuthEvents(nil)
 	for i := range stateEvents {
@@ -444,7 +476,8 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver
 	return gomatrixserverlib.Allowed(e, &authUsingState)
 }
-func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error {
+// nolint:gocyclo
+func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error {
 	// Do this with a fresh context, so that we keep working even if the
 	// original request times out. With any luck, by the time the remote
 	// side retries, we'll have fetched the missing state.
@@ -470,64 +503,146 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
 	// - fill in the gap completely then process event `e` returning no backwards extremity
 	// - fail to fill in the gap and tell us to terminate the transaction err=not nil
 	// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
-	backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn)
+	newEvents, err := t.getMissingEvents(gmectx, e, roomVersion)
 	if err != nil {
 		return err
 	}
-	if backwardsExtremity == nil {
-		// we filled in the gap!
+	if len(newEvents) == 0 {
 		return nil
 	}
-	// at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity.
-	// security: we have to do state resolution on the new backwards extremity (TODO: WHY)
-	// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
-	// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
-	var states []*gomatrixserverlib.RespState
-	needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
-	for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
-		var prevState *gomatrixserverlib.RespState
-		prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed)
-		if err != nil {
-			util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
-			return err
-		}
-		states = append(states, prevState)
-	}
-	resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity)
-	if err != nil {
-		util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
-		return err
-	}
-	// pass the event along with the state to the roomserver using a background context so we don't
-	// needlessly expire
-	return api.SendEventWithState(context.Background(), t.rsAPI, resolvedState, e.Headered(roomVersion), t.haveEventIDs())
+	backwardsExtremity := &newEvents[0]
+	newEvents = newEvents[1:]
+	type respState struct {
+		// A snapshot is considered trustworthy if it came from our own roomserver.
+		// That's because the state will have been through state resolution once
+		// already in QueryStateAfterEvent.
+		trustworthy bool
+		*gomatrixserverlib.RespState
+	}
+	// at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity.
+	// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
+	// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
+	var states []*respState
+	for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
+		// Look up what the state is after the backward extremity. This will either
+		// come from the roomserver, if we know all the required events, or it will
+		// come from a remote server via /state_ids if not.
+		prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
+		if lerr != nil {
+			util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
+			return lerr
+		}
+		// Append the state onto the collected state. We'll run this through the
+		// state resolution next.
+		states = append(states, &respState{trustworthy, prevState})
+	}
+	// Now that we have collected all of the state from the prev_events, we'll
+	// run the state through the appropriate state resolution algorithm for the
+	// room if needed. This does a couple of things:
+	// 1. Ensures that the state is deduplicated fully for each state-key tuple
+	// 2. Ensures that we pick the latest events from both sets, in the case that
+	// one of the prev_events is quite a bit older than the others
+	resolvedState := &gomatrixserverlib.RespState{}
+	switch len(states) {
+	case 0:
+		extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("")
+		if !extremityIsCreate {
+			// There are no previous states and this isn't the beginning of the
+			// room - this is an error condition!
+			util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events")
+			return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states))
+		}
+	case 1:
+		// There's only one previous state - if it's trustworthy (came from a
+		// local state snapshot which will already have been through state res),
+		// use it as-is. There's no point in resolving it again.
+		if states[0].trustworthy {
+			resolvedState = states[0].RespState
+			break
+		}
+		// Otherwise, if it isn't trustworthy (came from federation), run it through
+		// state resolution anyway for safety, in case there are duplicates.
+		fallthrough
+	default:
+		respStates := make([]*gomatrixserverlib.RespState, len(states))
+		for i := range states {
+			respStates[i] = states[i].RespState
+		}
+		// There's more than one previous state - run them all through state res
+		resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity)
+		if err != nil {
+			util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
+			return err
+		}
+	}
+	// First of all, send the backward extremity into the roomserver with the
+	// newly resolved state. This marks the "oldest" point in the backfill and
+	// sets the baseline state for any new events after this.
+	err = api.SendEventWithState(
+		context.Background(),
+		t.rsAPI,
+		resolvedState,
+		backwardsExtremity.Headered(roomVersion),
+		t.haveEventIDs(),
+	)
+	if err != nil {
+		return fmt.Errorf("api.SendEventWithState: %w", err)
+	}
+	// Then send all of the newer backfilled events, of which will all be newer
+	// than the backward extremity, into the roomserver without state. This way
+	// they will automatically fast-forward based on the room state at the
+	// extremity in the last step.
+	headeredNewEvents := make([]gomatrixserverlib.HeaderedEvent, len(newEvents))
+	for i, newEvent := range newEvents {
+		headeredNewEvents[i] = newEvent.Headered(roomVersion)
+	}
+	if err = api.SendEvents(
+		context.Background(),
+		t.rsAPI,
+		append(headeredNewEvents, e.Headered(roomVersion)),
+		api.DoNotSendToOtherServers,
+		nil,
+	); err != nil {
+		return fmt.Errorf("api.SendEvents: %w", err)
+	}
+	return nil
 }
 // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
 // added into the mix.
-func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) {
+func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) {
 	// try doing all this locally before we resort to querying federation
-	respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed)
+	respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
 	if respState != nil {
-		return respState, nil
+		return respState, true, nil
 	}
 	respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID)
 	if err != nil {
-		return nil, err
+		return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
 	}
+	servers := t.getServers(ctx, roomID)
+	if len(servers) > 5 {
+		servers = servers[:5]
+	}
 	// fetch the event we're missing and add it to the pile
-	h, err := t.lookupEvent(ctx, roomVersion, eventID, false)
+	h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers)
 	switch err.(type) {
 	case verifySigError:
-		return respState, nil
+		return respState, false, nil
 	case nil:
 		// do nothing
 	default:
-		return nil, err
+		return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
 	}
 	t.haveEvents[h.EventID()] = h
 	if h.StateKey() != nil {
@@ -545,15 +660,14 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
 		}
 	}
-	return respState, nil
+	return respState, false, nil
 }
-func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState {
+func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {
 	var res api.QueryStateAfterEventsResponse
 	err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
 		RoomID:       roomID,
 		PrevEventIDs: []string{eventID},
-		StateToFetch: needed,
 	}, &res)
 	if err != nil || !res.PrevEventsExist {
 		util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
@@ -628,7 +742,11 @@ retryAllowedState:
 	if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
 		switch missing := err.(type) {
 		case gomatrixserverlib.MissingAuthEventError:
-			h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true)
+			servers := t.getServers(ctx, backwardsExtremity.RoomID())
+			if len(servers) > 5 {
+				servers = servers[:5]
+			}
+			h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true, servers)
 			switch err2.(type) {
 			case verifySigError:
 				return &gomatrixserverlib.RespState{
@@ -658,11 +776,7 @@ retryAllowedState:
 // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns.
 // This means that we may recursively call this function, as we spider back up prev_events.
 // nolint:gocyclo
-func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) {
-	if !isInboundTxn {
-		// we've recursed here, so just take a state snapshot please!
-		return &e, nil
-	}
+func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.Event, err error) {
 	logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
 	needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e})
 	// query latest events (our trusted forward extremities)
@@ -673,7 +787,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event
 	var res api.QueryLatestEventsAndStateResponse
 	if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil {
 		logger.WithError(err).Warn("Failed to query latest events")
-		return &e, nil
+		return nil, err
 	}
 	latestEvents := make([]string, len(res.LatestEvents))
 	for i := range res.LatestEvents {
@@ -732,7 +846,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event
 	logger.Infof("get_missing_events returned %d events", len(missingResp.Events))
 	// topologically sort and sanity check that we are making forward progress
-	newEvents := gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents)
+	newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents)
 	shouldHaveSomeEventIDs := e.PrevEventIDs()
 	hasPrevEvent := false
 Event:
@@ -755,16 +869,9 @@ Event:
 			err: err,
 		}
 	}
-	// process the missing events then the event which started this whole thing
-	for _, ev := range append(newEvents, e) {
-		err := t.processEvent(ctx, ev, false)
-		if err != nil {
-			return nil, err
-		}
-	}
 	// we processed everything!
-	return nil, nil
+	return newEvents, nil
 }
 func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
@@ -844,6 +951,12 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 		"concurrent_requests": concurrentRequests,
 	}).Info("Fetching missing state at event")
+	// Get a list of servers to fetch from.
+	servers := t.getServers(ctx, roomID)
+	if len(servers) > 5 {
+		servers = servers[:5]
+	}
 	// Create a queue containing all of the missing event IDs that we want
 	// to retrieve.
 	pending := make(chan string, missingCount)
@@ -869,7 +982,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 	// Define what we'll do in order to fetch the missing event ID.
 	fetch := func(missingEventID string) {
 		var h *gomatrixserverlib.HeaderedEvent
-		h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false)
+		h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
 		switch err.(type) {
 		case verifySigError:
 			return
@@ -907,26 +1020,25 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 }
 func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) (
-	*gomatrixserverlib.RespState, error) {
+	*gomatrixserverlib.RespState, error) { // nolint:unparam
 	// create a RespState response using the response to /state_ids as a guide
-	respState := gomatrixserverlib.RespState{
-		AuthEvents:  make([]gomatrixserverlib.Event, len(stateIDs.AuthEventIDs)),
-		StateEvents: make([]gomatrixserverlib.Event, len(stateIDs.StateEventIDs)),
-	}
+	respState := gomatrixserverlib.RespState{}
 	for i := range stateIDs.StateEventIDs {
 		ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
 		if !ok {
-			return nil, fmt.Errorf("missing state event %s", stateIDs.StateEventIDs[i])
+			logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
+			continue
 		}
-		respState.StateEvents[i] = ev.Unwrap()
+		respState.StateEvents = append(respState.StateEvents, ev.Unwrap())
 	}
 	for i := range stateIDs.AuthEventIDs {
 		ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
 		if !ok {
-			return nil, fmt.Errorf("missing auth event %s", stateIDs.AuthEventIDs[i])
+			logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
+			continue
 		}
-		respState.AuthEvents[i] = ev.Unwrap()
+		respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap())
 	}
 	// We purposefully do not do auth checks on the returned events, as they will still
 	// be processed in the exact same way, just as a 'rejected' event
@@ -934,7 +1046,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
 	return &respState, nil
 }
-func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
+func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool, servers []gomatrixserverlib.ServerName) (*gomatrixserverlib.HeaderedEvent, error) {
 	if localFirst {
 		// fetch from the roomserver
 		queryReq := api.QueryEventsByIDRequest{
@@ -947,19 +1059,27 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
 			return &queryRes.Events[0], nil
 		}
 	}
-	txn, err := t.federation.GetEvent(ctx, t.Origin, missingEventID)
-	if err != nil || len(txn.PDUs) == 0 {
-		util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID")
-		return nil, err
-	}
-	pdu := txn.PDUs[0]
 	var event gomatrixserverlib.Event
-	event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
-	if err != nil {
-		util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID())
-		return nil, unmarshalError{err}
+	found := false
+	for _, serverName := range servers {
+		txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
+		if err != nil || len(txn.PDUs) == 0 {
+			util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")
+			continue
+		}
+		event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion)
+		if err != nil {
+			util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event")
+			continue
+		}
+		found = true
+		break
 	}
-	if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
+	if !found {
+		util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers))
+		return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers))
+	}
+	if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
 		util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
 		return nil, verifySigError{event.EventID(), err}
 	}

View file

@ -491,7 +491,7 @@ func TestTransactionFailAuthChecks(t *testing.T) {
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
return api.QueryMissingAuthPrevEventsResponse{ return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true, RoomExists: true,
MissingAuthEventIDs: []string{"create_event"}, MissingAuthEventIDs: []string{},
MissingPrevEventIDs: []string{}, MissingPrevEventIDs: []string{},
} }
}, },
@ -516,6 +516,23 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions
rsAPI = &testRoomserverAPI{ rsAPI = &testRoomserverAPI{
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
res := api.QueryEventsByIDResponse{}
for _, ev := range testEvents {
for _, id := range req.EventIDs {
if ev.EventID() == id {
res.Events = append(res.Events, ev)
}
}
}
return res
},
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryStateAfterEventsResponse{
PrevEventsExist: true,
StateEvents: testEvents[:5],
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
missingPrevEvent := []string{"missing_prev_event"} missingPrevEvent := []string{"missing_prev_event"}
if len(req.PrevEventIDs) == 1 { if len(req.PrevEventIDs) == 1 {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/statistics"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/internal/setup/kafka"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -55,6 +56,8 @@ func NewInternalAPI(
FailuresUntilBlacklist: cfg.FederationMaxRetries, FailuresUntilBlacklist: cfg.FederationMaxRetries,
} }
consumer, _ := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationSenderDB, cfg.Matrix.ServerName, federation, federationSenderDB, cfg.Matrix.ServerName, federation,
rsAPI, stats, rsAPI, stats,
@ -66,7 +69,7 @@ func NewInternalAPI(
) )
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(
cfg, base.KafkaConsumer, queues, cfg, consumer, queues,
federationSenderDB, rsAPI, federationSenderDB, rsAPI,
) )
if err = rsConsumer.Start(); err != nil { if err = rsConsumer.Start(); err != nil {
@ -74,13 +77,13 @@ func NewInternalAPI(
} }
tsConsumer := consumers.NewOutputEDUConsumer( tsConsumer := consumers.NewOutputEDUConsumer(
cfg, base.KafkaConsumer, queues, federationSenderDB, cfg, consumer, queues, federationSenderDB,
) )
if err := tsConsumer.Start(); err != nil { if err := tsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start typing server consumer") logrus.WithError(err).Panic("failed to start typing server consumer")
} }
keyConsumer := consumers.NewKeyChangeConsumer( keyConsumer := consumers.NewKeyChangeConsumer(
&base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, rsAPI, &base.Cfg.KeyServer, consumer, queues, federationSenderDB, rsAPI,
) )
if err := keyConsumer.Start(); err != nil { if err := keyConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start key server consumer") logrus.WithError(err).Panic("failed to start key server consumer")

View file

@ -109,6 +109,8 @@ func (a *FederationSenderInternalAPI) doRequest(
func (a *FederationSenderInternalAPI) GetUserDevices( func (a *FederationSenderInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID) return a.federation.GetUserDevices(ctx, s, userID)
}) })
@ -121,6 +123,8 @@ func (a *FederationSenderInternalAPI) GetUserDevices(
func (a *FederationSenderInternalAPI) ClaimKeys( func (a *FederationSenderInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, s, oneTimeKeys)
}) })
@ -145,6 +149,8 @@ func (a *FederationSenderInternalAPI) QueryKeys(
func (a *FederationSenderInternalAPI) Backfill( func (a *FederationSenderInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, s, roomID, limit, eventIDs)
}) })
@ -157,6 +163,8 @@ func (a *FederationSenderInternalAPI) Backfill(
func (a *FederationSenderInternalAPI) LookupState( func (a *FederationSenderInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) { ) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion)
}) })
@ -169,6 +177,8 @@ func (a *FederationSenderInternalAPI) LookupState(
func (a *FederationSenderInternalAPI) LookupStateIDs( func (a *FederationSenderInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) { ) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, s, roomID, eventID)
}) })
@ -181,6 +191,8 @@ func (a *FederationSenderInternalAPI) LookupStateIDs(
func (a *FederationSenderInternalAPI) GetEvent( func (a *FederationSenderInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID) return a.federation.GetEvent(ctx, s, eventID)
}) })
@ -193,6 +205,8 @@ func (a *FederationSenderInternalAPI) GetEvent(
func (a *FederationSenderInternalAPI) GetServerKeys( func (a *FederationSenderInternalAPI) GetServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, s gomatrixserverlib.ServerName,
) (gomatrixserverlib.ServerKeys, error) { ) (gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetServerKeys(ctx, s) return a.federation.GetServerKeys(ctx, s)
}) })
@ -205,6 +219,8 @@ func (a *FederationSenderInternalAPI) GetServerKeys(
func (a *FederationSenderInternalAPI) LookupServerKeys( func (a *FederationSenderInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) { ) ([]gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) { ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests) return a.federation.LookupServerKeys(ctx, s, keyRequests)
}) })
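The change repeated through this file is a bounded-context guard around each proxied federation call. A minimal sketch of the pattern, not part of the commit (the helper name is invented; the 30-second figure matches the one used above):

```
package example

import (
	"context"
	"time"
)

// Sketch of the guard added to each federation call in this file: derive a
// bounded context so a slow or unresponsive remote server cannot hold an
// internal API request open indefinitely.
func withFederationTimeout(parent context.Context, call func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(parent, time.Second*30)
	defer cancel()
	return call(ctx)
}
```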

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -42,7 +43,7 @@ type federatedJoin struct {
RoomID string RoomID string
} }
// PerformJoinRequest implements api.FederationSenderInternalAPI // PerformJoin implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformJoin( func (r *FederationSenderInternalAPI) PerformJoin(
ctx context.Context, ctx context.Context,
request *api.PerformJoinRequest, request *api.PerformJoinRequest,
@ -174,8 +175,10 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
// Work out if we support the room version that has been supplied in // Work out if we support the room version that has been supplied in
// the make_join response. // the make_join response.
// "If not provided, the room version is assumed to be either "1" or "2"."
// https://matrix.org/docs/spec/server_server/unstable#get-matrix-federation-v1-make-join-roomid-userid
if respMakeJoin.RoomVersion == "" { if respMakeJoin.RoomVersion == "" {
respMakeJoin.RoomVersion = gomatrixserverlib.RoomVersionV1 respMakeJoin.RoomVersion = setDefaultRoomVersionFromJoinEvent(respMakeJoin.JoinEvent)
} }
if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil { if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil {
return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err) return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err)
@ -193,6 +196,11 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
} }
// No longer reuse the request context from this point forward.
// We don't want the client timing out to interrupt the join.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(context.Background())
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
ctx, ctx,
@ -202,17 +210,23 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
) )
if err != nil { if err != nil {
r.statistics.ForServer(serverName).Failure() r.statistics.ForServer(serverName).Failure()
cancel()
return fmt.Errorf("r.federation.SendJoin: %w", err) return fmt.Errorf("r.federation.SendJoin: %w", err)
} }
r.statistics.ForServer(serverName).Success() r.statistics.ForServer(serverName).Success()
// Sanity-check the join response to ensure that it has a create
// event, that the room version is known, etc.
if err := sanityCheckSendJoinResponse(respSendJoin); err != nil {
cancel()
return fmt.Errorf("sanityCheckSendJoinResponse: %w", err)
}
// Process the join response in a goroutine. The idea here is // Process the join response in a goroutine. The idea here is
// that we'll try and wait for as long as possible for the work // that we'll try and wait for as long as possible for the work
// to complete, but if the client does give up waiting, we'll // to complete, but if the client does give up waiting, we'll
// still continue to process the join anyway so that we don't // still continue to process the join anyway so that we don't
// waste the effort. // waste the effort.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(context.Background())
go func() { go func() {
defer cancel() defer cancel()
@ -424,3 +438,53 @@ func (r *FederationSenderInternalAPI) PerformBroadcastEDU(
return nil return nil
} }
func sanityCheckSendJoinResponse(respSendJoin gomatrixserverlib.RespSendJoin) error {
// sanity check we have a create event and it has a known room version
for _, ev := range respSendJoin.AuthEvents {
if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") {
// make sure the room version is known
content := ev.Content()
verBody := struct {
Version string `json:"room_version"`
}{}
err := json.Unmarshal(content, &verBody)
if err != nil {
return err
}
if verBody.Version == "" {
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-create
// The version of the room. Defaults to "1" if the key does not exist.
verBody.Version = "1"
}
knownVersions := gomatrixserverlib.RoomVersions()
if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok {
return fmt.Errorf("send_join m.room.create event has an unknown room version: %s", verBody.Version)
}
return nil
}
}
return fmt.Errorf("send_join response is missing m.room.create event")
}
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
hasEventRefs := true
authEvents, ok := joinEvent.AuthEvents.([]interface{})
if ok {
if len(authEvents) > 0 {
_, ok = authEvents[0].(string)
if ok {
// event refs are objects, not strings, so we know we must be dealing with a v3+ room.
hasEventRefs = false
}
}
}
if hasEventRefs {
return gomatrixserverlib.RoomVersionV1
}
return gomatrixserverlib.RoomVersionV4
}
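For context (not part of the commit itself): the shape difference that setDefaultRoomVersionFromJoinEvent keys off is that room versions 1 and 2 list auth_events as event reference tuples, whereas version 3 and later list bare event ID strings. A rough illustration with invented event IDs and hashes:

```
package example

// Illustrative only: the two auth_events wire formats distinguished above.
// Event IDs and hashes are made up.
func exampleAuthEventFormats() (v1, v3plus []interface{}) {
	// Room versions 1/2: event references, i.e. [event_id, {"sha256": ...}] pairs.
	v1 = []interface{}{
		[]interface{}{"$fakeevent:example.org", map[string]interface{}{"sha256": "fakehash"}},
	}
	// Room versions 3 and later: bare event ID strings.
	v3plus = []interface{}{"$fakeEventIdWithoutDomain"}
	return
}
```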

View file

@ -349,7 +349,7 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
// TODO: we should check for 500-ish fails vs 400-ish here, // TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response // since we shouldn't queue things indefinitely in response
// to a 400-ish error // to a 400-ish error
ctx, cancel = context.WithTimeout(context.Background(), time.Second*30) ctx, cancel = context.WithTimeout(context.Background(), time.Minute*5)
defer cancel() defer cancel()
_, err = oq.client.SendTransaction(ctx, t) _, err = oq.client.SendTransaction(ctx, t)
switch err.(type) { switch err.(type) {

go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20201006143701-222e7423a5e3 github.com/matrix-org/gomatrixserverlib v0.0.0-20201015151920-aa4f62b827b8
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2

go.sum
View file

@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201006143701-222e7423a5e3 h1:lWR/w6rXKZJJU1yGHb2zem/EK7+aYhUcRgAOiouZAxk= github.com/matrix-org/gomatrixserverlib v0.0.0-20201015151920-aa4f62b827b8 h1:GF1PxbvImWDoz1DQZNMoaYtIqQXtyLAtmQOzwwmw1OI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201006143701-222e7423a5e3/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20201015151920-aa4f62b827b8/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -26,13 +26,9 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/matrix-org/naffka"
naffkaStorage "github.com/matrix-org/naffka/storage"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
@ -73,8 +69,8 @@ type BaseDendrite struct {
httpClient *http.Client httpClient *http.Client
Cfg *config.Dendrite Cfg *config.Dendrite
Caches *caching.Caches Caches *caching.Caches
KafkaConsumer sarama.Consumer // KafkaConsumer sarama.Consumer
KafkaProducer sarama.SyncProducer // KafkaProducer sarama.SyncProducer
} }
const HTTPServerTimeout = time.Minute * 5 const HTTPServerTimeout = time.Minute * 5
@ -106,14 +102,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo
logrus.WithError(err).Panicf("failed to start opentracing") logrus.WithError(err).Panicf("failed to start opentracing")
} }
var kafkaConsumer sarama.Consumer
var kafkaProducer sarama.SyncProducer
if cfg.Global.Kafka.UseNaffka {
kafkaConsumer, kafkaProducer = setupNaffka(cfg)
} else {
kafkaConsumer, kafkaProducer = setupKafka(cfg)
}
cache, err := caching.NewInMemoryLRUCache(true) cache, err := caching.NewInMemoryLRUCache(true)
if err != nil { if err != nil {
logrus.WithError(err).Warnf("Failed to create cache") logrus.WithError(err).Warnf("Failed to create cache")
@ -152,8 +140,6 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo
InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(),
apiHttpClient: &apiClient, apiHttpClient: &apiClient,
httpClient: &client, httpClient: &client,
KafkaConsumer: kafkaConsumer,
KafkaProducer: kafkaProducer,
} }
} }
@ -254,9 +240,9 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client {
// CreateFederationClient creates a new federation client. Should only be called // CreateFederationClient creates a new federation client. Should only be called
// once per component. // once per component.
func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient {
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClientWithTimeout(
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey,
b.Cfg.FederationSender.DisableTLSValidation, b.Cfg.FederationSender.DisableTLSValidation, time.Minute*5,
) )
client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString()))
return client return client
@ -336,31 +322,3 @@ func (b *BaseDendrite) SetupAndServeHTTP(
select {} select {}
} }
// setupKafka creates kafka consumer/producer pair from the config.
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
consumer, err := sarama.NewConsumer(cfg.Global.Kafka.Addresses, nil)
if err != nil {
logrus.WithError(err).Panic("failed to start kafka consumer")
}
producer, err := sarama.NewSyncProducer(cfg.Global.Kafka.Addresses, nil)
if err != nil {
logrus.WithError(err).Panic("failed to setup kafka producers")
}
return consumer, producer
}
// setupNaffka creates kafka consumer/producer pair from the config.
func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Global.Kafka.Database.ConnectionString))
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
naff, err := naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naff, naff
}

View file

@ -0,0 +1,53 @@
package kafka
import (
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/naffka"
naffkaStorage "github.com/matrix-org/naffka/storage"
"github.com/sirupsen/logrus"
)
func SetupConsumerProducer(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
if cfg.UseNaffka {
return setupNaffka(cfg)
}
return setupKafka(cfg)
}
// setupKafka creates kafka consumer/producer pair from the config.
func setupKafka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
consumer, err := sarama.NewConsumer(cfg.Addresses, nil)
if err != nil {
logrus.WithError(err).Panic("failed to start kafka consumer")
}
producer, err := sarama.NewSyncProducer(cfg.Addresses, nil)
if err != nil {
logrus.WithError(err).Panic("failed to setup kafka producers")
}
return consumer, producer
}
// In monolith mode with Naffka, we don't have the same constraints about
// consuming the same topic from more than one place like we do with Kafka.
// Therefore, we only ever open a single Naffka connection, since Naffka may
// be backed by SQLite.
var naffkaInstance *naffka.Naffka
// setupNaffka creates kafka consumer/producer pair from the config.
func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) {
if naffkaInstance != nil {
return naffkaInstance, naffkaInstance
}
naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString))
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
naffkaInstance, err = naffka.New(naffkaDB)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka")
}
return naffkaInstance, naffkaInstance
}
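To make the intent of this new helper concrete, here is a rough sketch (not part of the commit) of how a component now obtains its handles from its own config section instead of from BaseDendrite. The function name is invented; the call itself mirrors the usages elsewhere in this diff:

```
package example

import (
	"github.com/matrix-org/dendrite/internal/config"
	"github.com/matrix-org/dendrite/internal/setup/kafka"
)

// wireComponent is a hypothetical component constructor. With use_naffka
// enabled the helper returns the single shared Naffka instance; otherwise
// it dials the configured Kafka brokers.
func wireComponent(cfg *config.SyncAPI) {
	consumer, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
	_, _ = consumer, producer // a producer-only caller would discard the consumer
}
```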

View file

@ -15,7 +15,6 @@
package setup package setup
import ( import (
"github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi" "github.com/matrix-org/dendrite/clientapi"
@ -38,13 +37,11 @@ import (
// Monolith represents an instantiation of all dependencies required to build // Monolith represents an instantiation of all dependencies required to build
// all components of Dendrite, for use in monolith mode. // all components of Dendrite, for use in monolith mode.
type Monolith struct { type Monolith struct {
Config *config.Dendrite Config *config.Dendrite
AccountDB accounts.Database AccountDB accounts.Database
KeyRing *gomatrixserverlib.KeyRing KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient
KafkaConsumer sarama.Consumer
KafkaProducer sarama.SyncProducer
AppserviceAPI appserviceAPI.AppServiceQueryAPI AppserviceAPI appserviceAPI.AppServiceQueryAPI
EDUInternalAPI eduServerAPI.EDUServerInputAPI EDUInternalAPI eduServerAPI.EDUServerInputAPI
@ -61,7 +58,7 @@ type Monolith struct {
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) {
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB, csMux, &m.Config.ClientAPI, m.AccountDB,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,
@ -73,7 +70,7 @@ func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router
) )
mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, m.UserAPI, m.Client) mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
syncapi.AddPublicRoutes( syncapi.AddPublicRoutes(
csMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, csMux, m.UserAPI, m.RoomserverAPI,
m.KeyAPI, m.FedClient, &m.Config.SyncAPI, m.KeyAPI, m.FedClient, &m.Config.SyncAPI,
) )
} }

internal/sqlutil/migrate.go
View file

@ -0,0 +1,130 @@
package sqlutil
import (
"database/sql"
"fmt"
"runtime"
"sort"
"github.com/matrix-org/dendrite/internal/config"
"github.com/pressly/goose"
)
type Migrations struct {
registeredGoMigrations map[int64]*goose.Migration
}
func NewMigrations() *Migrations {
return &Migrations{
registeredGoMigrations: make(map[int64]*goose.Migration),
}
}
// Copy-pasted from goose directly to store migrations into a map we control
// AddMigration adds a migration.
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
m.AddNamedMigration(filename, up, down)
}
// AddNamedMigration : Add a named migration.
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := goose.NumericComponent(filename)
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := m.registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
}
m.registeredGoMigrations[v] = migration
}
// RunDeltas up to the latest version.
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
maxVer := goose.MaxVersion
minVer := int64(0)
migrations, err := m.collect(minVer, maxVer)
if err != nil {
return fmt.Errorf("RunDeltas: Failed to collect migrations: %w", err)
}
if props.ConnectionString.IsPostgres() {
if err = goose.SetDialect("postgres"); err != nil {
return err
}
} else if props.ConnectionString.IsSQLite() {
if err = goose.SetDialect("sqlite3"); err != nil {
return err
}
} else {
return fmt.Errorf("Unknown connection string: %s", props.ConnectionString)
}
for {
current, err := goose.EnsureDBVersion(db)
if err != nil {
return fmt.Errorf("RunDeltas: Failed to EnsureDBVersion: %w", err)
}
next, err := migrations.Next(current)
if err != nil {
if err == goose.ErrNoNextVersion {
return nil
}
return fmt.Errorf("RunDeltas: Failed to load next migration to %+v: %w", next, err)
}
if err = next.Up(db); err != nil {
return fmt.Errorf("RunDeltas: Failed to run migration: %w", err)
}
}
}
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
var migrations goose.Migrations
// Go migrations registered via goose.AddMigration().
for _, migration := range m.registeredGoMigrations {
v, err := goose.NumericComponent(migration.Source)
if err != nil {
return nil, err
}
if versionFilter(v, current, target) {
migrations = append(migrations, migration)
}
}
migrations = sortAndConnectMigrations(migrations)
return migrations, nil
}
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
}

View file

@ -93,7 +93,7 @@ func trackGoID(query string) {
if strings.HasPrefix(q, "SELECT") { if strings.HasPrefix(q, "SELECT") {
return // SELECTs can go on other goroutines return // SELECTs can go on other goroutines
} }
logrus.Warnf("unsafe goid: SQL executed not on an ExclusiveWriter: %s", q) logrus.Warnf("unsafe goid %d: SQL executed not on an ExclusiveWriter: %s", thisGoID, q)
} }
var dbConns = make(map[string]*sql.DB) var dbConns = make(map[string]*sql.DB)

View file

@ -1,6 +1,12 @@
package internal package internal
import "fmt" import (
"fmt"
"strings"
)
// the final version string
var version string
// -ldflags "-X github.com/matrix-org/dendrite/internal.branch=master" // -ldflags "-X github.com/matrix-org/dendrite/internal.branch=master"
var branch string var branch string
@ -16,12 +22,22 @@ const (
) )
func VersionString() string { func VersionString() string {
version := fmt.Sprintf("%d.%d.%d%s", VersionMajor, VersionMinor, VersionPatch, VersionTag)
if branch != "" {
version += fmt.Sprintf("-%s", branch)
}
if build != "" {
version += fmt.Sprintf("+%s", build)
}
return version return version
} }
func init() {
version = fmt.Sprintf("%d.%d.%d", VersionMajor, VersionMinor, VersionPatch)
if VersionTag != "" {
version += "-" + VersionTag
}
parts := []string{}
if build != "" {
parts = append(parts, build)
}
if branch != "" {
parts = append(parts, branch)
}
if len(parts) > 0 {
version += "+" + strings.Join(parts, ".")
}
}
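As a hedged illustration of what the init above produces (version numbers and hashes invented): a plain build yields something like `0.3.0` or `0.3.0-rc1`, while a build with `-ldflags "-X github.com/matrix-org/dendrite/internal.build=4f7a1b2 -X github.com/matrix-org/dendrite/internal.branch=master"` yields `0.3.0-rc1+4f7a1b2.master`, which VersionString simply returns:

```
package main

import (
	"fmt"

	"github.com/matrix-org/dendrite/internal"
)

func main() {
	// Prints the precomputed version string, e.g. "0.3.0" for a plain build
	// or "0.3.0-rc1+4f7a1b2.master" when build and branch were injected.
	fmt.Println(internal.VersionString())
}
```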

View file

@ -108,7 +108,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix
fedClient := gomatrixserverlib.NewFederationClient( fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, true, gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, true,
) )
fedClient.Client = *gomatrixserverlib.NewClientWithTransport(true, &roundTripper{tripper}) fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper})
return fedClient return fedClient
} }

View file

@ -15,10 +15,10 @@
package keyserver package keyserver
import ( import (
"github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
fedsenderapi "github.com/matrix-org/dendrite/federationsender/api" fedsenderapi "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/setup/kafka"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/inthttp"
@ -36,8 +36,10 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concrete implementation of the internal API. Callers // NewInternalAPI returns a concrete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, producer sarama.SyncProducer, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
) api.KeyInternalAPI { ) api.KeyInternalAPI {
_, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
db, err := storage.NewDatabase(&cfg.Database) db, err := storage.NewDatabase(&cfg.Database)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database") logrus.WithError(err).Panicf("failed to connect to key server database")

View file

@ -63,7 +63,8 @@ type QueryStateAfterEventsRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
// The list of previous events to return the events after. // The list of previous events to return the events after.
PrevEventIDs []string `json:"prev_event_ids"` PrevEventIDs []string `json:"prev_event_ids"`
// The state key tuples to fetch from the state // The state key tuples to fetch from the state. If none are specified then
// the entire resolved room state will be returned.
StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
} }

View file

@ -133,8 +133,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// If the event has already been written to the output log then we // If the event has already been written to the output log then we
// don't need to do anything, as we've handled it already. // don't need to do anything, as we've handled it already.
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID); err != nil {
if err != nil {
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
} else if hasBeenSent { } else if hasBeenSent {
return nil return nil
@ -142,17 +141,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Work out what the latest events are. This will include the new // Work out what the latest events are. This will include the new
// event if it is not already referenced. // event if it is not already referenced.
u.calculateLatest( if err := u.calculateLatest(
oldLatest, oldLatest,
types.StateAtEventAndReference{ types.StateAtEventAndReference{
EventReference: u.event.EventReference(), EventReference: u.event.EventReference(),
StateAtEvent: u.stateAtEvent, StateAtEvent: u.stateAtEvent,
}, },
) ); err != nil {
return fmt.Errorf("u.calculateLatest: %w", err)
}
// Now that we know what the latest events are, it's time to get the // Now that we know what the latest events are, it's time to get the
// latest state. // latest state.
if err = u.latestState(); err != nil { if err := u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err) return fmt.Errorf("u.latestState: %w", err)
} }
@ -261,7 +262,7 @@ func (u *latestEventsUpdater) latestState() error {
func (u *latestEventsUpdater) calculateLatest( func (u *latestEventsUpdater) calculateLatest(
oldLatest []types.StateAtEventAndReference, oldLatest []types.StateAtEventAndReference,
newEvent types.StateAtEventAndReference, newEvent types.StateAtEventAndReference,
) { ) error {
var newLatest []types.StateAtEventAndReference var newLatest []types.StateAtEventAndReference
// First of all, let's see if any of the existing forward extremities // First of all, let's see if any of the existing forward extremities
@ -271,6 +272,7 @@ func (u *latestEventsUpdater) calculateLatest(
referenced, err := u.updater.IsReferenced(l.EventReference) referenced, err := u.updater.IsReferenced(l.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID) logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID)
return fmt.Errorf("u.updater.IsReferenced (old): %w", err)
} else if !referenced { } else if !referenced {
newLatest = append(newLatest, l) newLatest = append(newLatest, l)
} }
@ -285,7 +287,7 @@ func (u *latestEventsUpdater) calculateLatest(
// We've already referenced this new event so we can just return // We've already referenced this new event so we can just return
// the newly completed extremities at this point. // the newly completed extremities at this point.
u.latest = newLatest u.latest = newLatest
return return nil
} }
} }
@ -296,11 +298,13 @@ func (u *latestEventsUpdater) calculateLatest(
referenced, err := u.updater.IsReferenced(newEvent.EventReference) referenced, err := u.updater.IsReferenced(newEvent.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID) logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID)
return fmt.Errorf("u.updater.IsReferenced (new): %w", err)
} else if !referenced || len(newLatest) == 0 { } else if !referenced || len(newLatest) == 0 {
newLatest = append(newLatest, newEvent) newLatest = append(newLatest, newEvent)
} }
u.latest = newLatest u.latest = newLatest
return nil
} }
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {

View file

@ -49,6 +49,7 @@ func (r *Queryer) QueryLatestEventsAndState(
} }
// QueryStateAfterEvents implements api.RoomserverInternalAPI // QueryStateAfterEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *Queryer) QueryStateAfterEvents( func (r *Queryer) QueryStateAfterEvents(
ctx context.Context, ctx context.Context,
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
@ -78,10 +79,18 @@ func (r *Queryer) QueryStateAfterEvents(
} }
response.PrevEventsExist = true response.PrevEventsExist = true
// Look up the currrent state for the requested tuples. var stateEntries []types.StateEntry
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( if len(request.StateToFetch) == 0 {
ctx, prevStates, request.StateToFetch, // Look up all of the current room state.
) stateEntries, err = roomState.LoadCombinedStateAfterEvents(
ctx, prevStates,
)
} else {
// Look up the current state for the requested tuples.
stateEntries, err = roomState.LoadStateAfterEventsForStringTuples(
ctx, prevStates, request.StateToFetch,
)
}
if err != nil { if err != nil {
return err return err
} }
@ -91,6 +100,24 @@ func (r *Queryer) QueryStateAfterEvents(
return err return err
} }
if len(request.PrevEventIDs) > 1 && len(request.StateToFetch) == 0 {
var authEventIDs []string
for _, e := range stateEvents {
authEventIDs = append(authEventIDs, e.AuthEventIDs()...)
}
authEventIDs = util.UniqueStrings(authEventIDs)
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
stateEvents, err = state.ResolveConflictsAdhoc(info.RoomVersion, stateEvents, authEvents)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
}
}
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
} }
@ -122,7 +149,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents(
} }
for _, prevEventID := range request.PrevEventIDs { for _, prevEventID := range request.PrevEventIDs {
if nids, err := r.DB.EventNIDs(ctx, []string{prevEventID}); err != nil || len(nids) == 0 { if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err != nil || len(state) == 0 {
response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID)
} }
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/internal/setup/kafka"
"github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/internal"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -41,6 +42,8 @@ func NewInternalAPI(
) api.RoomserverInternalAPI { ) api.RoomserverInternalAPI {
cfg := &base.Cfg.RoomServer cfg := &base.Cfg.RoomServer
_, producer := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
var perspectiveServerNames []gomatrixserverlib.ServerName var perspectiveServerNames []gomatrixserverlib.ServerName
for _, kp := range base.Cfg.SigningKeyServer.KeyPerspectives { for _, kp := range base.Cfg.SigningKeyServer.KeyPerspectives {
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
@ -52,7 +55,7 @@ func NewInternalAPI(
} }
return internal.NewRoomserverAPI( return internal.NewRoomserverAPI(
cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), cfg, roomserverDB, producer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
base.Caches, keyRing, perspectiveServerNames, base.Caches, keyRing, perspectiveServerNames,
) )
} }

View file

@ -17,7 +17,10 @@ import (
"github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -160,7 +163,9 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro
cfg.Defaults() cfg.Defaults()
cfg.Global.ServerName = testOrigin cfg.Global.ServerName = testOrigin
cfg.Global.Kafka.UseNaffka = true cfg.Global.Kafka.UseNaffka = true
cfg.RoomServer.Database.ConnectionString = config.DataSource(roomserverDBFileURI) cfg.RoomServer.Database = config.DatabaseOptions{
ConnectionString: roomserverDBFileURI,
}
dp := &dummyProducer{ dp := &dummyProducer{
topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent), topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent),
} }
@ -169,12 +174,17 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro
t.Fatalf("failed to make caches: %s", err) t.Fatalf("failed to make caches: %s", err)
} }
base := &setup.BaseDendrite{ base := &setup.BaseDendrite{
KafkaProducer: dp, Caches: cache,
Caches: cache, Cfg: cfg,
Cfg: cfg,
} }
roomserverDB, err := storage.Open(&cfg.RoomServer.Database, base.Caches)
return NewInternalAPI(base, &test.NopJSONVerifier{}), dp if err != nil {
logrus.WithError(err).Panicf("failed to connect to room server db")
}
return internal.NewRoomserverAPI(
&cfg.RoomServer, roomserverDB, dp, string(cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent)),
base.Caches, &test.NopJSONVerifier{}, nil,
), dp
} }
func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) {

View file

@ -39,7 +39,8 @@ CREATE INDEX IF NOT EXISTS roomserver_redactions_redacts_event_id ON roomserver_
const insertRedactionSQL = "" + const insertRedactionSQL = "" +
"INSERT INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + "INSERT INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
" VALUES ($1, $2, $3)" " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
const selectRedactionInfoByRedactionEventIDSQL = "" + const selectRedactionInfoByRedactionEventIDSQL = "" +
"SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +

View file

@ -18,23 +18,30 @@ type LatestEventsUpdater struct {
currentStateSnapshotNID types.StateSnapshotNID currentStateSnapshotNID types.StateSnapshotNID
} }
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck rollback(txn)
return nil, err return nil, err
} }
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck rollback(txn)
return nil, err return nil, err
} }
var lastEventIDSent string var lastEventIDSent string
if lastEventNIDSent != 0 { if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck rollback(txn)
return nil, err return nil, err
} }
} }

View file

@ -22,12 +22,21 @@ func NewMembershipUpdater(
ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string, ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string,
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (*MembershipUpdater, error) { ) (*MembershipUpdater, error) {
roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) var roomNID types.RoomNID
if err != nil { var targetUserNID types.EventStateKeyNID
return nil, err var err error
} err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, roomID, roomVersion)
if err != nil {
return err
}
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) targetUserNID, err = d.assignStateKeyNID(ctx, txn, targetUserID)
if err != nil {
return err
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -458,7 +458,7 @@ func (d *Database) StoreEvent(
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
} }
if err != nil { if err != nil {
return err return fmt.Errorf("d.EventsTable.SelectEvent: %w", err)
} }
} }
@ -467,6 +467,9 @@ func (d *Database) StoreEvent(
} }
if !isRejected { // ignore rejected redaction events if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
if err != nil {
return fmt.Errorf("d.handleRedactions: %w", err)
}
} }
return nil return nil
}) })
@ -627,6 +630,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
// to cross-reference with other tables when loading. // to cross-reference with other tables when loading.
// //
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
// nolint:gocyclo
func (d *Database) handleRedactions( func (d *Database) handleRedactions(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event,
) (*gomatrixserverlib.Event, string, error) { ) (*gomatrixserverlib.Event, string, error) {
@ -644,13 +648,13 @@ func (d *Database) handleRedactions(
RedactsEventID: event.Redacts(), RedactsEventID: event.Redacts(),
}) })
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
} }
} }
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event) redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event)
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
} }
if validated || redactedEvent == nil || redactionEvent == nil { if validated || redactedEvent == nil || redactionEvent == nil {
// we've seen this redaction before or there is nothing to redact // we've seen this redaction before or there is nothing to redact
@ -664,7 +668,7 @@ func (d *Database) handleRedactions(
// mark the event as redacted // mark the event as redacted
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
} }
if redactionsArePermanent { if redactionsArePermanent {
redactedEvent.Event = redactedEvent.Redact() redactedEvent.Event = redactedEvent.Redact()
@ -672,10 +676,15 @@ func (d *Database) handleRedactions(
// overwrite the eventJSON table // overwrite the eventJSON table
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
return &redactionEvent.Event, redactedEvent.EventID(), d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
if err != nil {
err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
}
return &redactionEvent.Event, redactedEvent.EventID(), err
} }
// loadRedactionPair returns both the redaction event and the redacted event, else nil. // loadRedactionPair returns both the redaction event and the redacted event, else nil.

View file

@ -37,7 +37,7 @@ CREATE TABLE IF NOT EXISTS roomserver_redactions (
` `
const insertRedactionSQL = "" + const insertRedactionSQL = "" +
"INSERT INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
" VALUES ($1, $2, $3)" " VALUES ($1, $2, $3)"
const selectRedactionInfoByRedactionEventIDSQL = "" + const selectRedactionInfoByRedactionEventIDSQL = "" +

View file

@ -224,10 +224,6 @@ func (s *ServerKeyAPI) handleFetcherKeys(
// Now let's look at the results that we got from this fetcher. // Now let's look at the results that we got from this fetcher.
for req, res := range fetcherResults { for req, res := range fetcherResults {
if req.ServerName == s.ServerName {
continue
}
if prev, ok := results[req]; ok { if prev, ok := results[req]; ok {
// We've already got a previous entry for this request // We've already got a previous entry for this request
// so let's see if the newly retrieved one contains a more // so let's see if the newly retrieved one contains a more

View file

@ -17,11 +17,11 @@ package syncapi
import ( import (
"context" "context"
"github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/setup/kafka"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -37,13 +37,14 @@ import (
// component. // component.
func AddPublicRoutes( func AddPublicRoutes(
router *mux.Router, router *mux.Router,
consumer sarama.Consumer,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
keyAPI keyapi.KeyInternalAPI, keyAPI keyapi.KeyInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
cfg *config.SyncAPI, cfg *config.SyncAPI,
) { ) {
consumer, _ := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) syncDB, err := storage.NewSyncServerDatasource(&cfg.Database)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to sync db") logrus.WithError(err).Panicf("failed to connect to sync db")

View file

@@ -68,6 +68,8 @@ Request to logout with invalid an access token is rejected
Request to logout without an access token is rejected
Room creation reports m.room.create to myself
Room creation reports m.room.member to myself
+ Outbound federation rejects send_join responses with no m.room.create event
+ Outbound federation rejects m.room.create events with an unknown room version
Invited user can see room metadata
# Blacklisted because these tests call /r0/events which we don't implement
# New room members see their own join event
@@ -480,3 +482,5 @@ m.room.history_visibility == "joined" allows/forbids appropriately for Real user
POST rejects invalid utf-8 in JSON
Users cannot kick users who have already left a room
A prev_batch token from incremental sync can be used in the v1 messages API
+ Event with an invalid signature in the send_join response should not cause room join to fail
+ Inbound federation rejects typing notifications from wrong remote

View file

@@ -75,11 +75,12 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
+ func (s *accountsStatements) execSchema(db *sql.DB) error {
+ _, err := db.Exec(accountsSchema)
+ return err
+ }
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
- _, err = db.Exec(accountsSchema)
- if err != nil {
- return
- }
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}

View file

@@ -0,0 +1,33 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive)
}
func LoadIsActive(m *sqlutil.Migrations) {
m.AddMigration(UpIsActive, DownIsActive)
}
func UpIsActive(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownIsActive(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}
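This new deltas package registers the same up/down pair twice: once with goose for the standalone migration tool, and once with the internal sqlutil.Migrations runner used at startup. Based on the calls that appear later in this change, driving the runner presumably looks roughly like the hypothetical wrapper below (db and dbProperties are supplied by the caller):

```
package example

import (
	"database/sql"

	"github.com/matrix-org/dendrite/internal/config"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
)

// runAccountDeltas shows the registration/run flow: collect the deltas into
// a Migrations set, then execute any that have not been applied yet.
func runAccountDeltas(db *sql.DB, dbProperties *config.DatabaseOptions) error {
	m := sqlutil.NewMigrations()
	deltas.LoadIsActive(m)
	return m.RunDeltas(db, dbProperties)
}
```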

View file

@@ -1,9 +0,0 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE account_accounts DROP COLUMN is_deactivated;
-- +goose StatementEnd

View file

@@ -25,6 +25,8 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
+ _ "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
@@ -55,6 +57,18 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
db: db,
writer: sqlutil.NewDummyWriter(),
}
+ // Create tables before executing migrations so we don't fail if the table is missing,
+ // and THEN prepare statements so we don't fail due to referencing new columns
+ if err = d.accounts.execSchema(db); err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrations()
+ deltas.LoadIsActive(m)
+ if err = m.RunDeltas(db, dbProperties); err != nil {
+ return nil, err
+ }
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
@@ -70,6 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
return d, nil
}
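The ordering in NewDatabase is the important part: execute the base schema first so a fresh database has the table at all, run the deltas next so existing deployments gain the new column, and only then prepare statements so queries that mention is_deactivated do not fail against an old schema. A generic, dependency-free sketch of that same ordering, using hypothetical helper names rather than Dendrite's sqlutil API:

```
package example

import (
	"database/sql"
	"fmt"
)

type migration func(*sql.Tx) error

// openAndPrepare mirrors the pattern above: schema, then deltas, then prepared statements.
func openAndPrepare(db *sql.DB, schema string, deltas []migration, prepare func(*sql.DB) error) error {
	// 1. Make sure the base tables exist before anything references them.
	if _, err := db.Exec(schema); err != nil {
		return fmt.Errorf("exec schema: %w", err)
	}
	// 2. Apply column-level changes that older deployments are missing.
	for i, delta := range deltas {
		tx, err := db.Begin()
		if err != nil {
			return err
		}
		if err := delta(tx); err != nil {
			_ = tx.Rollback()
			return fmt.Errorf("delta %d: %w", i, err)
		}
		if err := tx.Commit(); err != nil {
			return err
		}
	}
	// 3. Only now prepare statements, so they can reference the new columns.
	return prepare(db)
}
```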

View file

@@ -74,13 +74,13 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
+ func (s *accountsStatements) execSchema(db *sql.DB) error {
+ _, err := db.Exec(accountsSchema)
+ return err
+ }
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- _, err = db.Exec(accountsSchema)
- if err != nil {
- return
- }
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}

View file

@@ -0,0 +1,64 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpIsActive, DownIsActive)
}
func LoadIsActive(m *sqlutil.Migrations) {
m.AddMigration(UpIsActive, DownIsActive)
}
func UpIsActive(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
is_deactivated BOOLEAN DEFAULT 0
);
INSERT
INTO account_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownIsActive(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT
);
INSERT
INTO account_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
;
DROP TABLE account_accounts_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}
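Unlike the PostgreSQL delta, the SQLite version has to rebuild the table: SQLite has no ADD COLUMN IF NOT EXISTS and (until 3.35) no DROP COLUMN, so the usual pattern is rename, recreate with the desired shape, copy the rows across, and drop the temporary table. A small throwaway check of that pattern, with an illustrative table rather than the real account_accounts schema:

```
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	mustExec(db, `CREATE TABLE accounts (localpart TEXT PRIMARY KEY, password_hash TEXT)`)
	mustExec(db, `INSERT INTO accounts VALUES ('alice', 'x'), ('bob', 'y')`)

	// Same shape as the migration above: rename, recreate with the extra column, copy, drop.
	mustExec(db, `
		ALTER TABLE accounts RENAME TO accounts_tmp;
		CREATE TABLE accounts (
			localpart TEXT PRIMARY KEY,
			password_hash TEXT,
			is_deactivated BOOLEAN DEFAULT 0
		);
		INSERT INTO accounts (localpart, password_hash)
			SELECT localpart, password_hash FROM accounts_tmp;
		DROP TABLE accounts_tmp;`)

	var n int
	if err := db.QueryRow(`SELECT COUNT(*) FROM accounts WHERE is_deactivated = 0`).Scan(&n); err != nil {
		log.Fatal(err)
	}
	fmt.Println("rows carried over:", n) // rows carried over: 2
}

func mustExec(db *sql.DB, q string) {
	if _, err := db.Exec(q); err != nil {
		log.Fatal(err)
	}
}
```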

View file

@@ -1,38 +0,0 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
is_deactivated BOOLEAN DEFAULT 0
);
INSERT
INTO account_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
;
DROP TABLE account_accounts_tmp;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT
);
INSERT
INTO account_accounts (
localpart, created_ts, password_hash, appservice_id
) SELECT
localpart, created_ts, password_hash, appservice_id
FROM account_accounts_tmp
;
DROP TABLE account_accounts_tmp;
-- +goose StatementEnd

View file

@@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the sqlite3 database driver.
@@ -60,6 +61,18 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
db: db,
writer: sqlutil.NewExclusiveWriter(),
}
+ // Create tables before executing migrations so we don't fail if the table is missing,
+ // and THEN prepare statements so we don't fail due to referencing new columns
+ if err = d.accounts.execSchema(db); err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrations()
+ deltas.LoadIsActive(m)
+ if err = m.RunDeltas(db, dbProperties); err != nil {
+ return nil, err
+ }
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
@@ -76,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
return d, nil
}

View file

@@ -0,0 +1,39 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func UpLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}
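The column default EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000 stores last_seen_ts in milliseconds since the Unix epoch, the unit in which Matrix timestamps are generally expressed. The equivalent value on the Go side would presumably be computed like this:

```
package main

import (
	"fmt"
	"time"
)

func main() {
	// Milliseconds since the Unix epoch, the same unit as the
	// EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000 column default.
	lastSeenTS := time.Now().UnixNano() / int64(time.Millisecond)
	fmt.Println(lastSeenTS)
}
```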

View file

@@ -1,13 +0,0 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;
-- +goose StatementEnd

View file

@@ -111,11 +111,12 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
+ func (s *devicesStatements) execSchema(db *sql.DB) error {
+ _, err := db.Exec(devicesSchema)
+ return err
+ }
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
- _, err = db.Exec(devicesSchema)
- if err != nil {
- return
- }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}

View file

@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -42,9 +43,22 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
return nil, err
}
d := devicesStatements{}
+ // Create tables before executing migrations so we don't fail if the table is missing,
+ // and THEN prepare statements so we don't fail due to referencing new columns
+ if err = d.execSchema(db); err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrations()
+ deltas.LoadLastSeenTSIP(m)
+ if err = m.RunDeltas(db, dbProperties); err != nil {
+ return nil, err
+ }
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}

View file

@@ -0,0 +1,70 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
func UpLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@@ -1,44 +0,0 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;
-- +goose StatementEnd

View file

@@ -98,13 +98,14 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
+ func (s *devicesStatements) execSchema(db *sql.DB) error {
+ _, err := db.Exec(devicesSchema)
+ return err
+ }
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
- _, err = db.Exec(devicesSchema)
- if err != nil {
- return
- }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}

View file

@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
@@ -46,6 +47,17 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
+ // Create tables before executing migrations so we don't fail if the table is missing,
+ // and THEN prepare statements so we don't fail due to referencing new columns
+ if err = d.execSchema(db); err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrations()
+ deltas.LoadLastSeenTSIP(m)
+ if err = m.RunDeltas(db, dbProperties); err != nil {
+ return nil, err
+ }
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}