diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index bbb1d4c7a..44e52286f 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -164,6 +164,7 @@ func startup() { cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.SyncAPI.Database.ConnectionString = "file:/idb/dendritejs_syncapi.db" + cfg.KeyServer.Database.ConnectionString = "file:/idb/dendritejs_e2ekey.db" cfg.Global.JetStream.StoragePath = "file:/idb/dendritejs/" cfg.Global.TrustedIDServers = []string{} cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 582437bb8..79407f30d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -103,6 +103,12 @@ client_api: federation_api: database: connection_string: file:federationapi.db +key_server: + database: + connection_string: file:keyserver.db + max_open_conns: 100 + max_idle_conns: 2 + conn_max_lifetime: -1 media_api: database: connection_string: file:mediaapi.db diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e77..51bd2753a 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3ce5af621..47d330959 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index b5206dd85..bc5ae652d 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,7 +18,7 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index 64fe8bafa..8b9704d1b 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -169,7 +169,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. // something if any of the specified keys in the request are different // to what we've got in the database, to avoid generating key change // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), @@ -216,7 +216,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. } // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } @@ -234,7 +234,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. continue } for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } @@ -373,7 +373,7 @@ func (a *UserInternalAPI) processSelfSignatures( } for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -384,7 +384,7 @@ func (a *UserInternalAPI) processSelfSignatures( case *gomatrixserverlib.DeviceKeys: for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -442,7 +442,7 @@ func (a *UserInternalAPI) processOtherSignatures( } for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -465,7 +465,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue @@ -478,7 +478,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( break } - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue @@ -524,7 +524,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), @@ -556,7 +556,7 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig for _, targetKeyID := range forTargetUser { // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 332a72f10..868fc9be8 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -361,12 +360,12 @@ func TestDebounce(t *testing.T) { } } -func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() base, _, _ := testrig.Base(nil) connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}, "localhost", bcrypt.MinCost, 2000, time.Second, "") + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 1d243fff6..be816fe5d 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -33,7 +33,7 @@ import ( ) func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), @@ -53,7 +53,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } @@ -83,7 +83,7 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform continue } // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), @@ -160,7 +160,7 @@ func (a *UserInternalAPI) claimRemoteKeys( } func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } @@ -169,7 +169,7 @@ func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perfor } func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), @@ -181,7 +181,7 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn } func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -208,7 +208,7 @@ func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Quer // PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present // in our database. func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) if err != nil { return err } @@ -245,7 +245,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque domain := string(serverName) // query local devices if a.Config.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -323,14 +323,14 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} } for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -349,14 +349,14 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque for targetUserID, forUserID := range res.DeviceKeys { for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -571,7 +571,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -625,7 +625,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), @@ -644,7 +644,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) } else { logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) @@ -704,7 +704,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), @@ -724,10 +724,10 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor } } if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), } } if counts != nil { @@ -743,7 +743,7 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: "failed to query existing one-time keys: " + err.Error(), @@ -760,7 +760,7 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor } } // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go index 8f8bc8d9c..fc7e7e0df 100644 --- a/userapi/internal/key_api_test.go +++ b/userapi/internal/key_api_test.go @@ -4,7 +4,6 @@ import ( "context" "reflect" "testing" - "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -12,16 +11,15 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/storage" - "golang.org/x/crypto/bcrypt" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) base, _, _ := testrig.Base(nil) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, "localhost", bcrypt.MinCost, 2000, time.Second, "") + }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } @@ -148,7 +146,7 @@ func Test_QueryDeviceMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &internal.UserInternalAPI{ - DB: db, + KeyDatabase: db, } if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 6957de42c..1cbd97190 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -44,7 +44,8 @@ import ( ) type UserInternalAPI struct { - DB storage.Database + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase SyncProducer *producers.SyncAPI KeyChangeProducer *producers.KeyChange Config *config.UserAPI diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 68efca5b8..165de8994 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -25,7 +25,7 @@ type SyncAPI struct { notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 31f340ac8..278378861 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -142,7 +142,6 @@ type Database interface { Pusher Statistics ThreePID - KeyserverDatabase } type KeyChangeDatabase interface { @@ -151,7 +150,7 @@ type KeyChangeDatabase interface { StoreKeyChange(ctx context.Context, userID string) (int64, error) } -type KeyserverDatabase interface { +type KeyDatabase interface { KeyChangeDatabase // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 3b5908113..5c1c38de5 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -104,9 +104,41 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } - if err != nil { + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { return nil, err } + + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, + Stats: statsTable, + ServerName: serverName, + DB: db, + Writer: writer, + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) otk, err := NewPostgresOneTimeKeysTable(db) if err != nil { return nil, err @@ -132,41 +164,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } - m = sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names populate", - Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNamesPopulate(ctx, txn, serverName) - }, - }) - if err = m.Up(base.Context()); err != nil { - return nil, err - } - - return &shared.Database{ - AccountDatas: accountDataTable, - Accounts: accountsTable, - Devices: devicesTable, - KeyBackups: keyBackupTable, - KeyBackupVersions: keyBackupVersionTable, - LoginTokens: loginTokenTable, - OpenIDTokens: openIDTable, - Profiles: profilesTable, - ThreePIDs: threePIDTable, - Pushers: pusherTable, - Notifications: notificationsTable, - Stats: statsTable, + return &shared.KeyDatabase{ OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, CrossSigningKeysTable: csk, CrossSigningSigsTable: css, - ServerName: serverName, - DB: db, Writer: writer, - LoginTokenLifetime: loginTokenLifetime, - BcryptCost: bcryptCost, - OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index cfc149bf1..d3272a032 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -53,16 +53,21 @@ type Database struct { Notifications tables.NotificationTable Pushers tables.PusherTable Stats tables.StatsTable + LoginTokenLifetime time.Duration + ServerName gomatrixserverlib.ServerName + BcryptCost int + OpenIDTokenLifetimeMS int64 +} + +type KeyDatabase struct { OneTimeKeysTable tables.OneTimeKeys DeviceKeysTable tables.DeviceKeys KeyChangesTable tables.KeyChanges StaleDeviceListsTable tables.StaleDeviceLists CrossSigningKeysTable tables.CrossSigningKeys CrossSigningSigsTable tables.CrossSigningSigs - LoginTokenLifetime time.Duration - ServerName gomatrixserverlib.ServerName - BcryptCost int - OpenIDTokenLifetimeMS int64 + DB *sql.DB + Writer sqlutil.Writer } const ( @@ -884,11 +889,11 @@ func (d *Database) DailyRoomsMessages( // -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) } -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) return err @@ -896,15 +901,15 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) ( return } -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) if err != nil { return false, err @@ -912,7 +917,7 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i return count == len(prevIDs), nil } -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for _, userID := range clearUserIDs { err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) @@ -924,7 +929,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM }) } -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { // work out the latest stream IDs for each user userIDToStreamID := make(map[string]int64) for _, k := range keys { @@ -949,11 +954,11 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) } -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { var result []api.OneTimeKeys err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for userID, deviceToAlgo := range userToDeviceToAlgorithm { @@ -976,7 +981,7 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st return result, err } -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) return err @@ -984,18 +989,18 @@ func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, return } -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } // MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) }) @@ -1003,7 +1008,7 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for _, deviceID := range deviceIDs { if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { @@ -1021,7 +1026,7 @@ func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID } // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) if err != nil { return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) @@ -1060,17 +1065,17 @@ func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) ( } // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) } // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for keyType, keyData := range keyMap { if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { @@ -1082,7 +1087,7 @@ func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID stri } // StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, @@ -1097,7 +1102,7 @@ func (d *Database) StoreCrossSigningSigsForTarget( } // DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. -func (d *Database) DeleteStaleDeviceLists( +func (d *KeyDatabase) DeleteStaleDeviceLists( ctx context.Context, userIDs []string, ) error { diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 0c38d3328..0f3eeed1b 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -102,6 +102,41 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, + Stats: statsTable, + ServerName: serverName, + DB: db, + Writer: writer, + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil +} + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err } @@ -130,41 +165,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } - m = sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names populate", - Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNamesPopulate(ctx, txn, serverName) - }, - }) - if err = m.Up(base.Context()); err != nil { - return nil, err - } - - return &shared.Database{ - AccountDatas: accountDataTable, - Accounts: accountsTable, - Devices: devicesTable, - KeyBackups: keyBackupTable, - KeyBackupVersions: keyBackupVersionTable, - LoginTokens: loginTokenTable, - OpenIDTokens: openIDTable, - Profiles: profilesTable, - ThreePIDs: threePIDTable, - Pushers: pusherTable, - Notifications: notificationsTable, - Stats: statsTable, + return &shared.KeyDatabase{ OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, CrossSigningKeysTable: csk, CrossSigningSigsTable: css, - ServerName: serverName, - DB: db, Writer: writer, - LoginTokenLifetime: loginTokenLifetime, - BcryptCost: bcryptCost, - OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index a9de53f95..0329fb46a 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -39,13 +39,26 @@ func NewUserDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index b92ae1610..f52e7e17d 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -32,7 +32,7 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ @@ -50,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -81,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -241,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -318,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -350,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -371,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -420,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -471,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -510,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -575,9 +575,9 @@ func Test_Notification(t *testing.T) { }) } -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewUserDatabase(base, &base.Cfg.KeyServer.Database, "localhost", bcrypt.MinCost, 2000, time.Second, "") + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) } @@ -594,7 +594,7 @@ func MustNotError(t *testing.T, err error) { func TestKeyChanges(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) + db, clean := mustCreateKeyDatabase(t, dbType) defer clean() _, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) @@ -617,7 +617,7 @@ func TestKeyChanges(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) + db, clean := mustCreateKeyDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) @@ -643,7 +643,7 @@ func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) + db, clean := mustCreateKeyDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) @@ -672,7 +672,7 @@ var deviceArray = []string{"AAA", "another_device"} func TestDeviceKeysStreamIDGeneration(t *testing.T) { var err error test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) + db, clean := mustCreateKeyDatabase(t, dbType) defer clean() alice := "@alice:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration" diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e6..163e3e173 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index 8eba0a445..826bd7213 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -57,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -69,11 +74,12 @@ func NewInternalAPI( keyChangeProducer := &producers.KeyChange{ Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), JetStream: js, - DB: db, + DB: keyDB, } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, KeyChangeProducer: keyChangeProducer, Config: cfg, @@ -84,7 +90,7 @@ func NewInternalAPI( FedClient: fedClient, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable userAPI.Updater = updater // Remove users which we don't share a room with anymore if err := updater.CleanUp(); err != nil { diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 5f853d52e..08b1336be 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -50,7 +50,7 @@ func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, err return &nats.PubAck{}, nil } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } @@ -67,6 +67,13 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ @@ -76,9 +83,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") - keyChangeProducer := &producers.KeyChange{DB: accountDB, JetStream: &dummyProducer{}} + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}} return &internal.UserInternalAPI{ DB: accountDB, + KeyDatabase: keyDB, Config: cfg, SyncProducer: syncProducer, KeyChangeProducer: keyChangeProducer, diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc7999..31617d8c1 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39bf..08d1371d6 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err