mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-17 11:03:11 -06:00
Separate keyserver database again
This commit is contained in:
parent
ada6f04bd2
commit
24c1226359
|
|
@ -164,6 +164,7 @@ func startup() {
|
|||
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
||||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||
cfg.SyncAPI.Database.ConnectionString = "file:/idb/dendritejs_syncapi.db"
|
||||
cfg.KeyServer.Database.ConnectionString = "file:/idb/dendritejs_e2ekey.db"
|
||||
cfg.Global.JetStream.StoragePath = "file:/idb/dendritejs/"
|
||||
cfg.Global.TrustedIDServers = []string{}
|
||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
|
|
|
|||
|
|
@ -103,6 +103,12 @@ client_api:
|
|||
federation_api:
|
||||
database:
|
||||
connection_string: file:federationapi.db
|
||||
key_server:
|
||||
database:
|
||||
connection_string: file:keyserver.db
|
||||
max_open_conns: 100
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
media_api:
|
||||
database:
|
||||
connection_string: file:mediaapi.db
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct {
|
|||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
db storage.UserDatabase
|
||||
serverName gomatrixserverlib.ServerName
|
||||
syncProducer *producers.SyncAPI
|
||||
pgClient pushgateway.Client
|
||||
|
|
@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer(
|
|||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
store storage.Database,
|
||||
store storage.UserDatabase,
|
||||
syncProducer *producers.SyncAPI,
|
||||
pgClient pushgateway.Client,
|
||||
) *OutputReceiptEventConsumer {
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
|
|||
rsAPI rsapi.UserRoomserverAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
db storage.Database
|
||||
db storage.UserDatabase
|
||||
topic string
|
||||
pgClient pushgateway.Client
|
||||
syncProducer *producers.SyncAPI
|
||||
|
|
@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer(
|
|||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
store storage.Database,
|
||||
store storage.UserDatabase,
|
||||
pgClient pushgateway.Client,
|
||||
rsAPI rsapi.UserRoomserverAPI,
|
||||
syncProducer *producers.SyncAPI,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import (
|
|||
userAPITypes "github.com/matrix-org/dendrite/userapi/types"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.
|
|||
// something if any of the specified keys in the request are different
|
||||
// to what we've got in the database, to avoid generating key change
|
||||
// notifications unnecessarily.
|
||||
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||
|
|
@ -216,7 +216,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.
|
|||
}
|
||||
|
||||
// Store the keys.
|
||||
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
|
||||
if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
|
||||
}
|
||||
|
|
@ -234,7 +234,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.
|
|||
continue
|
||||
}
|
||||
for sigKeyID, sigBytes := range forSigUserID {
|
||||
if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
|
||||
}
|
||||
|
|
@ -373,7 +373,7 @@ func (a *UserInternalAPI) processSelfSignatures(
|
|||
}
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.DB.StoreCrossSigningSigsForTarget(
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
|
|
@ -384,7 +384,7 @@ func (a *UserInternalAPI) processSelfSignatures(
|
|||
case *gomatrixserverlib.DeviceKeys:
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.DB.StoreCrossSigningSigsForTarget(
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
|
|
@ -442,7 +442,7 @@ func (a *UserInternalAPI) processOtherSignatures(
|
|||
}
|
||||
|
||||
for originKeyID, originSig := range userSigs {
|
||||
if err := a.DB.StoreCrossSigningSigsForTarget(
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
|
|
@ -465,7 +465,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
|||
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
|
||||
) {
|
||||
for targetUserID := range req.UserToDevices {
|
||||
keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
|
||||
continue
|
||||
|
|
@ -478,7 +478,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
|||
break
|
||||
}
|
||||
|
||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
|
||||
continue
|
||||
|
|
@ -524,7 +524,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
|||
|
||||
func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
|
||||
for targetUserID, forTargetUser := range req.TargetIDs {
|
||||
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
|
||||
|
|
@ -556,7 +556,7 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig
|
|||
|
||||
for _, targetKeyID := range forTargetUser {
|
||||
// Get own signatures only.
|
||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -361,12 +360,12 @@ func TestDebounce(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
|
||||
base, _, _ := testrig.Base(nil)
|
||||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}, "localhost", bcrypt.MinCost, 2000, time.Second, "")
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ import (
|
|||
)
|
||||
|
||||
func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
|
||||
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
|
|
@ -53,7 +53,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -83,7 +83,7 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform
|
|||
continue
|
||||
}
|
||||
// claim local keys
|
||||
keys, err := a.DB.ClaimKeys(ctx, local)
|
||||
keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
||||
|
|
@ -160,7 +160,7 @@ func (a *UserInternalAPI) claimRemoteKeys(
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
|
||||
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
|
||||
if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
|
||||
}
|
||||
|
|
@ -169,7 +169,7 @@ func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perfor
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
|
||||
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||
|
|
@ -181,7 +181,7 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||
|
|
@ -208,7 +208,7 @@ func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Quer
|
|||
// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
|
||||
// in our database.
|
||||
func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
|
||||
knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
|
||||
knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -245,7 +245,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque
|
|||
domain := string(serverName)
|
||||
// query local devices
|
||||
if a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||
|
|
@ -323,14 +323,14 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque
|
|||
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for targetKeyID := range masterKey.Keys {
|
||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
|
|
@ -349,14 +349,14 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque
|
|||
|
||||
for targetUserID, forUserID := range res.DeviceKeys {
|
||||
for targetKeyID, key := range forUserID {
|
||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
|
|
@ -571,7 +571,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
|||
func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||
|
|
@ -625,7 +625,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
// Get all of the user existing device keys so we can check for changes.
|
||||
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||
existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||
|
|
@ -644,7 +644,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
if len(toClean) > 0 {
|
||||
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||
if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
|
||||
} else {
|
||||
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
|
||||
|
|
@ -704,7 +704,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
// store the device keys and emit changes
|
||||
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
|
|
@ -724,10 +724,10 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
}
|
||||
}
|
||||
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
|
||||
counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
|
|
@ -743,7 +743,7 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time keys: " + err.Error(),
|
||||
|
|
@ -760,7 +760,7 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.DB.StoreOneTimeKeys(ctx, key)
|
||||
counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
|
|
@ -12,16 +11,15 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
base, _, _ := testrig.Base(nil)
|
||||
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, 2000, time.Second, "")
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new user db: %v", err)
|
||||
}
|
||||
|
|
@ -148,7 +146,7 @@ func Test_QueryDeviceMessages(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &internal.UserInternalAPI{
|
||||
DB: db,
|
||||
KeyDatabase: db,
|
||||
}
|
||||
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
|
||||
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ import (
|
|||
)
|
||||
|
||||
type UserInternalAPI struct {
|
||||
DB storage.Database
|
||||
DB storage.UserDatabase
|
||||
KeyDatabase storage.KeyDatabase
|
||||
SyncProducer *producers.SyncAPI
|
||||
KeyChangeProducer *producers.KeyChange
|
||||
Config *config.UserAPI
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ type SyncAPI struct {
|
|||
notificationDataTopic string
|
||||
}
|
||||
|
||||
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
|
||||
func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
|
||||
return &SyncAPI{
|
||||
db: db,
|
||||
producer: js,
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ type KeyBackup interface {
|
|||
|
||||
type LoginToken interface {
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
// determined by the loginTokenLifetime given to the UserDatabase constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
|
|
@ -130,7 +130,7 @@ type Notification interface {
|
|||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
type UserDatabase interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
|
|
@ -142,7 +142,6 @@ type Database interface {
|
|||
Pusher
|
||||
Statistics
|
||||
ThreePID
|
||||
KeyserverDatabase
|
||||
}
|
||||
|
||||
type KeyChangeDatabase interface {
|
||||
|
|
@ -151,7 +150,7 @@ type KeyChangeDatabase interface {
|
|||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
}
|
||||
|
||||
type KeyserverDatabase interface {
|
||||
type KeyDatabase interface {
|
||||
KeyChangeDatabase
|
||||
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
|
||||
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
|
||||
|
|
|
|||
|
|
@ -104,9 +104,41 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
|
||||
otk, err := NewPostgresOneTimeKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -132,41 +164,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,16 +53,21 @@ type Database struct {
|
|||
Notifications tables.NotificationTable
|
||||
Pushers tables.PusherTable
|
||||
Stats tables.StatsTable
|
||||
LoginTokenLifetime time.Duration
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
BcryptCost int
|
||||
OpenIDTokenLifetimeMS int64
|
||||
}
|
||||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
CrossSigningKeysTable tables.CrossSigningKeys
|
||||
CrossSigningSigsTable tables.CrossSigningSigs
|
||||
LoginTokenLifetime time.Duration
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
BcryptCost int
|
||||
OpenIDTokenLifetimeMS int64
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
@ -884,11 +889,11 @@ func (d *Database) DailyRoomsMessages(
|
|||
|
||||
//
|
||||
|
||||
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
|
||||
}
|
||||
|
||||
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
|
||||
func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
|
||||
return err
|
||||
|
|
@ -896,15 +901,15 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
||||
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
@ -912,7 +917,7 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
|
|||
return count == len(prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||
func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for _, userID := range clearUserIDs {
|
||||
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
|
||||
|
|
@ -924,7 +929,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
// work out the latest stream IDs for each user
|
||||
userIDToStreamID := make(map[string]int64)
|
||||
for _, k := range keys {
|
||||
|
|
@ -949,11 +954,11 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||
}
|
||||
|
||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
var result []api.OneTimeKeys
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
|
||||
|
|
@ -976,7 +981,7 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
|||
return result, err
|
||||
}
|
||||
|
||||
func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||
func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
|
||||
return err
|
||||
|
|
@ -984,18 +989,18 @@ func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64,
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
|
||||
})
|
||||
|
|
@ -1003,7 +1008,7 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta
|
|||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
|
||||
func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for _, deviceID := range deviceIDs {
|
||||
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
|
||||
|
|
@ -1021,7 +1026,7 @@ func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID
|
|||
}
|
||||
|
||||
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
|
||||
func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
|
||||
func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
|
||||
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
|
||||
|
|
@ -1060,17 +1065,17 @@ func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (
|
|||
}
|
||||
|
||||
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
|
||||
func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
|
||||
func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
|
||||
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
|
||||
}
|
||||
|
||||
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
|
||||
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
|
||||
func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
|
||||
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
|
||||
}
|
||||
|
||||
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
|
||||
func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
|
||||
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for keyType, keyData := range keyMap {
|
||||
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
|
||||
|
|
@ -1082,7 +1087,7 @@ func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID stri
|
|||
}
|
||||
|
||||
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
|
||||
func (d *Database) StoreCrossSigningSigsForTarget(
|
||||
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
|
||||
ctx context.Context,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
|
|
@ -1097,7 +1102,7 @@ func (d *Database) StoreCrossSigningSigsForTarget(
|
|||
}
|
||||
|
||||
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
|
||||
func (d *Database) DeleteStaleDeviceLists(
|
||||
func (d *KeyDatabase) DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error {
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
// NewUserDatabase creates a new accounts and profiles database
|
||||
func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -102,6 +102,41 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -130,41 +165,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
Devices: devicesTable,
|
||||
KeyBackups: keyBackupTable,
|
||||
KeyBackupVersions: keyBackupVersionTable,
|
||||
LoginTokens: loginTokenTable,
|
||||
OpenIDTokens: openIDTable,
|
||||
Profiles: profilesTable,
|
||||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,13 +39,26 @@ func NewUserDatabase(
|
|||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (Database, error) {
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
||||
// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
|
||||
// and sets postgres connection parameters.
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewKeyDatabase(base, dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewKeyDatabase(base, dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ var (
|
|||
ctx = context.Background()
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||
|
|
@ -50,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
|
|||
// Tests storing and getting account data
|
||||
func Test_AccountData(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
|
@ -81,7 +81,7 @@ func Test_AccountData(t *testing.T) {
|
|||
// Tests the creation of accounts
|
||||
func Test_Accounts(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
|
@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) {
|
|||
accessToken := util.RandomString(16)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
|
|
@ -241,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantAuthData := json.RawMessage("my auth data")
|
||||
|
|
@ -318,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
func Test_LoginToken(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create a new token
|
||||
|
|
@ -350,7 +350,7 @@ func Test_OpenID(t *testing.T) {
|
|||
token := util.RandomString(24)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
|
||||
|
|
@ -371,7 +371,7 @@ func Test_Profile(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
|
|
@ -420,7 +420,7 @@ func Test_Pusher(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
appID := util.RandomString(8)
|
||||
|
|
@ -471,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
|
|
@ -510,7 +510,7 @@ func Test_Notification(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
// generate some dummy notifications
|
||||
for i := 0; i < 10; i++ {
|
||||
|
|
@ -575,9 +575,9 @@ func Test_Notification(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.NewUserDatabase(base, &base.Cfg.KeyServer.Database, "localhost", bcrypt.MinCost, 2000, time.Second, "")
|
||||
db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new database: %v", err)
|
||||
}
|
||||
|
|
@ -594,7 +594,7 @@ func MustNotError(t *testing.T, err error) {
|
|||
|
||||
func TestKeyChanges(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
|
|
@ -617,7 +617,7 @@ func TestKeyChanges(t *testing.T) {
|
|||
|
||||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
|
|
@ -643,7 +643,7 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
|||
|
||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
|
|
@ -672,7 +672,7 @@ var deviceArray = []string{"AAA", "another_device"}
|
|||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||
var err error
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := MustCreateDatabase(t, dbType)
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
|
|
|
|||
|
|
@ -32,10 +32,10 @@ func NewUserAPIDatabase(
|
|||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (Database, error) {
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -57,6 +57,11 @@ func NewInternalAPI(
|
|||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||
}
|
||||
|
||||
keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to key db")
|
||||
}
|
||||
|
||||
syncProducer := producers.NewSyncAPI(
|
||||
db, js,
|
||||
// TODO: user API should handle syncs for account data. Right now,
|
||||
|
|
@ -69,11 +74,12 @@ func NewInternalAPI(
|
|||
keyChangeProducer := &producers.KeyChange{
|
||||
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
|
||||
JetStream: js,
|
||||
DB: db,
|
||||
DB: keyDB,
|
||||
}
|
||||
|
||||
userAPI := &internal.UserInternalAPI{
|
||||
DB: db,
|
||||
KeyDatabase: keyDB,
|
||||
SyncProducer: syncProducer,
|
||||
KeyChangeProducer: keyChangeProducer,
|
||||
Config: cfg,
|
||||
|
|
@ -84,7 +90,7 @@ func NewInternalAPI(
|
|||
FedClient: fedClient,
|
||||
}
|
||||
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||
userAPI.Updater = updater
|
||||
// Remove users which we don't share a room with anymore
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, err
|
|||
return &nats.PubAck{}, nil
|
||||
}
|
||||
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
|
||||
if opts.loginTokenLifetime == 0 {
|
||||
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||
}
|
||||
|
|
@ -67,6 +67,13 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
|||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
|
||||
keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create key DB: %s", err)
|
||||
}
|
||||
|
||||
cfg := &config.UserAPI{
|
||||
Matrix: &config.Global{
|
||||
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||
|
|
@ -76,9 +83,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
|||
}
|
||||
|
||||
syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
|
||||
keyChangeProducer := &producers.KeyChange{DB: accountDB, JetStream: &dummyProducer{}}
|
||||
keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
|
||||
return &internal.UserInternalAPI{
|
||||
DB: accountDB,
|
||||
KeyDatabase: keyDB,
|
||||
Config: cfg,
|
||||
SyncProducer: syncProducer,
|
||||
KeyChangeProducer: keyChangeProducer,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ type PusherDevice struct {
|
|||
}
|
||||
|
||||
// GetPushDevices pushes to the configured devices of a local user.
|
||||
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.GetPushers: %w", err)
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ import (
|
|||
)
|
||||
|
||||
// NotifyUserCountsAsync sends notifications to a local user's
|
||||
// notification destinations. Database lookups run synchronously, but
|
||||
// notification destinations. UserDatabase lookups run synchronously, but
|
||||
// a single goroutine is started when talking to the Push
|
||||
// gateways. There is no way to know when the background goroutine has
|
||||
// finished.
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
Loading…
Reference in a new issue