diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 48b1ae88d..0ac637793 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -17,6 +17,7 @@ package producers import ( "context" "encoding/json" + "fmt" "strconv" "time" @@ -83,7 +84,7 @@ func (p *SyncAPIProducer) SendReceipt( m.Header.Set(jetstream.RoomID, roomID) m.Header.Set(jetstream.EventID, eventID) m.Header.Set("type", receiptType) - m.Header.Set("timestamp", strconv.Itoa(int(timestamp))) + m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp)) log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent) _, err := p.JetStream.PublishMsg(m, nats.Context(ctx)) diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 9300451eb..2c9d79bcb 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -90,7 +90,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bo return true } - timestamp, err := strconv.Atoi(msg.Header.Get("timestamp")) + timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64) if err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Errorf("EDU output log: message parse failure") diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index e371baaaa..43dd08dd8 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -53,7 +53,7 @@ func (p *SyncAPIProducer) SendReceipt( m.Header.Set(jetstream.RoomID, roomID) m.Header.Set(jetstream.EventID, eventID) m.Header.Set("type", receiptType) - m.Header.Set("timestamp", strconv.Itoa(int(timestamp))) + m.Header.Set("timestamp", fmt.Sprintf("%d", timestamp)) log.WithFields(log.Fields{}).Tracef("Producing to topic '%s'", p.TopicReceiptEvent) _, err := p.JetStream.PublishMsg(m, nats.Context(ctx)) diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 9940eac60..44cfb5f2a 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -1,36 +1,26 @@ -package storage +package storage_test import ( "context" - "fmt" - "io/ioutil" - "log" - "os" "reflect" "testing" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" ) var ctx = context.Background() -func MustCreateDatabase(t *testing.T) (Database, func()) { - tmpfile, err := ioutil.TempFile("", "keyserver_storage_test") +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) if err != nil { - log.Fatal(err) - } - t.Logf("Database %s", tmpfile.Name()) - db, err := NewDatabase(nil, &config.DatabaseOptions{ - ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())), - }) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } - return db, func() { - os.Remove(tmpfile.Name()) + t.Fatalf("failed to create new database: %v", err) } + return db, close } func MustNotError(t *testing.T, err error) { @@ -42,151 +32,159 @@ func MustNotError(t *testing.T, err error) { } func TestKeyChanges(t *testing.T) { - db, clean := MustCreateDatabase(t) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := MustCreateDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) } func TestKeyChangesNoDupes(t *testing.T) { - db, clean := MustCreateDatabase(t) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := MustCreateDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) } func TestKeyChangesUpperLimit(t *testing.T) { - db, clean := MustCreateDatabase(t) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := MustCreateDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) } // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // and that they are returned correctly when querying for device keys. func TestDeviceKeysStreamIDGeneration(t *testing.T) { var err error - db, clean := MustCreateDatabase(t) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := MustCreateDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) } - } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index d59b8be7a..1a11586a5 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -39,6 +40,7 @@ type RoomserverInternalAPI struct { *perform.Upgrader *perform.Admin ProcessContext *process.ProcessContext + Base *base.BaseDendrite DB storage.Database Cfg *config.RoomServer Cache caching.RoomServerCaches @@ -56,33 +58,38 @@ type RoomserverInternalAPI struct { } func NewRoomserverAPI( - processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database, - js nats.JetStreamContext, nc *nats.Conn, inputRoomEventTopic string, - caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName, + base *base.BaseDendrite, roomserverDB storage.Database, + js nats.JetStreamContext, nc *nats.Conn, ) *RoomserverInternalAPI { + var perspectiveServerNames []gomatrixserverlib.ServerName + for _, kp := range base.Cfg.FederationAPI.KeyPerspectives { + perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) + } + serverACLs := acls.NewServerACLs(roomserverDB) producer := &producers.RoomEventProducer{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent)), + Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)), JetStream: js, ACLs: serverACLs, } a := &RoomserverInternalAPI{ - ProcessContext: processCtx, + ProcessContext: base.ProcessContext, DB: roomserverDB, - Cfg: cfg, - Cache: caches, - ServerName: cfg.Matrix.ServerName, + Base: base, + Cfg: &base.Cfg.RoomServer, + Cache: base.Caches, + ServerName: base.Cfg.Global.ServerName, PerspectiveServerNames: perspectiveServerNames, - InputRoomEventTopic: inputRoomEventTopic, + InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent), OutputProducer: producer, JetStream: js, NATSClient: nc, - Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"), + Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), ServerACLs: serverACLs, Queryer: &query.Queryer{ DB: roomserverDB, - Cache: caches, - ServerName: cfg.Matrix.ServerName, + Cache: base.Caches, + ServerName: base.Cfg.Global.ServerName, ServerACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use @@ -98,8 +105,9 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.KeyRing = keyRing r.Inputer = &input.Inputer{ - Cfg: r.Cfg, - ProcessContext: r.ProcessContext, + Cfg: &r.Base.Cfg.RoomServer, + Base: r.Base, + ProcessContext: r.Base.ProcessContext, DB: r.DB, InputRoomEventTopic: r.InputRoomEventTopic, OutputProducer: r.OutputProducer, diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index fa07c1d2b..ecd4ecbb5 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -69,6 +70,7 @@ import ( // or C. type Inputer struct { Cfg *config.RoomServer + Base *base.BaseDendrite ProcessContext *process.ProcessContext DB storage.Database NATSClient *nats.Conn @@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) { // will look to see if we have a worker for that room which has its // own consumer. If we don't, we'll start one. func (r *Inputer) Start() error { - prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) + if r.Base.EnableMetrics { + prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) + } _, err := r.JetStream.Subscribe( "", // This is blank because we specified it in BindStream. func(m *nats.Msg) { diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index eb68100fe..1f707735b 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -17,13 +17,10 @@ package roomserver import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/inthttp" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/roomserver/internal" + "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/jetstream" "github.com/sirupsen/logrus" ) @@ -40,11 +37,6 @@ func NewInternalAPI( ) api.RoomserverInternalAPI { cfg := &base.Cfg.RoomServer - var perspectiveServerNames []gomatrixserverlib.ServerName - for _, kp := range base.Cfg.FederationAPI.KeyPerspectives { - perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) - } - roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches) if err != nil { logrus.WithError(err).Panicf("failed to connect to room server db") @@ -53,8 +45,6 @@ func NewInternalAPI( js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) return internal.NewRoomserverAPI( - base.ProcessContext, cfg, roomserverDB, js, nc, - cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEvent), - base.Caches, perspectiveServerNames, + base, roomserverDB, js, nc, ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go new file mode 100644 index 000000000..4e98af853 --- /dev/null +++ b/roomserver/roomserver_test.go @@ -0,0 +1,69 @@ +package roomserver_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return base, db, close +} + +func Test_SharedUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, _, close := mustCreateDatabase(t, dbType) + defer close() + + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Fatalf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Fatalf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + }) +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 6c4e4b860..91f271652 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -110,7 +110,7 @@ func (v *StateResolution) LoadStateAtEvent( snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { - return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) } if snapshotNID == 0 { return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index c01753c3a..ce626ad1d 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` -var selectJoinedUsersSetForRoomsSQL = "" + +var selectJoinedUsersSetForRoomsAndUserSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -153,6 +159,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt @@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, @@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms( roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { + var ( + rows *sql.Rows + err error + ) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) + if len(userNIDs) > 0 { + stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt) + rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) + } else { + rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs)) + } + if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 67dcfdf38..3191280cb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -263,6 +263,12 @@ func (d *Database) snapshotNIDFromEventID( ctx context.Context, txn *sql.Tx, eventID string, ) (types.StateSnapshotNID, error) { _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) + if err != nil { + return 0, err + } + if stateNID == 0 { + return 0, sql.ErrNoRows // effectively there's no state entry + } return stateNID, err } @@ -1214,6 +1220,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ stateKeyNIDs[i] = nid i++ } + // If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now + if len(userIDs) == 0 { + nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs) + if err != nil { + return nil, err + } + } result := make(map[string]int, len(userNIDToCount)) for nid, count := range userNIDToCount { result[nidToUserID[nid]] = count diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 6f0fe8b64..570d3919c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -41,12 +41,18 @@ const membershipSchema = ` ); ` -var selectJoinedUsersSetForRoomsSQL = "" + +var selectJoinedUsersSetForRoomsAndUserSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND " + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, for _, v := range userNIDs { params = append(params, v) } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) + if len(userNIDs) > 0 { + query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) + } var rows *sql.Rows var err error if txn != nil { diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index ff3287714..9b89fc9af 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -187,7 +187,7 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error { } // Load the config data into our struct - if err = yaml.UnmarshalStrict(configData, &appservice); err != nil { + if err = yaml.Unmarshal(configData, &appservice); err != nil { return err } @@ -315,6 +315,20 @@ func checkErrors(config *AppServiceAPI, derived *Derived) (err error) { } } + // Check required fields + if appservice.ID == "" { + return ConfigErrors([]string{"Application service ID is required"}) + } + if appservice.ASToken == "" { + return ConfigErrors([]string{"Application service Token is required"}) + } + if appservice.HSToken == "" { + return ConfigErrors([]string{"Homeserver Token is required"}) + } + if appservice.SenderLocalpart == "" { + return ConfigErrors([]string{"Sender Localpart is required"}) + } + // Check if the url has trailing /'s. If so, remove them appservice.URL = strings.TrimRight(appservice.URL, "/") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 6bb0747f0..83156cf93 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -87,7 +87,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms Type: msg.Header.Get("type"), } - timestamp, err := strconv.Atoi(msg.Header.Get("timestamp")) + timestamp, err := strconv.ParseUint(msg.Header.Get("timestamp"), 10, 64) if err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Errorf("output log: message parse failure")