diff --git a/userapi/producers/keychange.go b/userapi/producers/keychange.go index b1fb29123..f7664b10f 100644 --- a/userapi/producers/keychange.go +++ b/userapi/producers/keychange.go @@ -29,7 +29,7 @@ import ( type KeyChange struct { Topic string JetStream nats.JetStreamContext - DB storage.KeyserverDatabase + DB storage.KeyChangeDatabase } // ProduceKeyChanges creates new change events for each key diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa9856..68efca5b8 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,7 +19,7 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 3320b4f48..31f340ac8 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -145,7 +145,14 @@ type Database interface { KeyserverDatabase } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + type KeyserverDatabase interface { + KeyChangeDatabase // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) @@ -185,10 +192,6 @@ type KeyserverDatabase interface { // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). // A to offset of types.OffsetNewest means no upper limit. // Returns the offset of the latest key change. diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d74..057160374 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 85755c74e..17127ba00 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -21,7 +21,9 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/setup/config" @@ -38,6 +40,13 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string +} + +type dummyProducer struct{} + +func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) { + return &nats.PubAck{}, nil } func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { @@ -46,9 +55,13 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } @@ -56,14 +69,17 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") + return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + Config: cfg, + SyncProducer: syncProducer, }, accountDB, func() { close() baseclose() @@ -332,3 +348,87 @@ func TestQueryAccountByLocalpart(t *testing.T) { testCases(t, intAPI) }) } + +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +}