From d3c7b4195d5958771b79fcb1f4ba948ab9a694a6 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Tue, 5 Jul 2022 11:02:51 +0200 Subject: [PATCH] Fix QuerySharedUsers if no UserIDs are supplied --- roomserver/internal/input/input.go | 6 +- roomserver/roomserver_test.go | 69 +++++++++++++++++++ .../storage/postgres/membership_table.go | 22 +++++- roomserver/storage/shared/storage.go | 7 ++ .../storage/sqlite3/membership_table.go | 14 +++- 5 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 roomserver/roomserver_test.go diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index fa07c1d2b..ecd4ecbb5 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -69,6 +70,7 @@ import ( // or C. type Inputer struct { Cfg *config.RoomServer + Base *base.BaseDendrite ProcessContext *process.ProcessContext DB storage.Database NATSClient *nats.Conn @@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) { // will look to see if we have a worker for that room which has its // own consumer. If we don't, we'll start one. func (r *Inputer) Start() error { - prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) + if r.Base.EnableMetrics { + prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) + } _, err := r.JetStream.Subscribe( "", // This is blank because we specified it in BindStream. func(m *nats.Msg) { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go new file mode 100644 index 000000000..4e98af853 --- /dev/null +++ b/roomserver/roomserver_test.go @@ -0,0 +1,69 @@ +package roomserver_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return base, db, close +} + +func Test_SharedUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, _, close := mustCreateDatabase(t, dbType) + defer close() + + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Fatalf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Fatalf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + }) +} diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index c01753c3a..ce626ad1d 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` -var selectJoinedUsersSetForRoomsSQL = "" + +var selectJoinedUsersSetForRoomsAndUserSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -153,6 +159,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt @@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, @@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms( roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]int, error) { + var ( + rows *sql.Rows + err error + ) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) - rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) + if len(userNIDs) > 0 { + stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt) + rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) + } else { + rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs)) + } + if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 67dcfdf38..5c633122d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1214,6 +1214,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ stateKeyNIDs[i] = nid i++ } + // If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now + if len(userIDs) == 0 { + nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs) + if err != nil { + return nil, err + } + } result := make(map[string]int, len(userNIDToCount)) for nid, count := range userNIDToCount { result[nidToUserID[nid]] = count diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 6f0fe8b64..570d3919c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -41,12 +41,18 @@ const membershipSchema = ` ); ` -var selectJoinedUsersSetForRoomsSQL = "" + +var selectJoinedUsersSetForRoomsAndUserSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " GROUP BY target_nid" +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + + " WHERE room_nid IN ($1) AND " + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, for _, v := range userNIDs { params = append(params, v) } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) + if len(userNIDs) > 0 { + query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) + } var rows *sql.Rows var err error if txn != nil {