diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8b02f3d6c..70cd8f718 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -24,6 +24,8 @@ import ( "sync" "time" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -102,6 +104,7 @@ type DeviceListUpdater struct { // block on or timeout via a select. userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error } type DeviceListUpdaterAPI interface { @@ -140,6 +145,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + rsAPI rsapi.KeyserverRoomserverAPI, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -152,11 +158,16 @@ func NewDeviceListUpdater( workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, } } // Start the device list updater, which will try to refresh any stale device lists. func (u *DeviceListUpdater) Start() error { + if err := u.cleanUp(); err != nil { + return fmt.Errorf("failed to cleanup stale device lists: %w", err) + } + for i := 0; i < len(u.workerChans); i++ { // Allocate a small buffer per channel. // If the buffer limit is reached, backpressure will cause the processing of EDUs @@ -166,7 +177,7 @@ func (u *DeviceListUpdater) Start() error { go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } @@ -184,6 +195,36 @@ func (u *DeviceListUpdater) Start() error { return nil } +// cleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) cleanUp() error { + if u.rsAPI == nil { + return nil + } + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + maxRetries := 3 + + // In polylith mode, the roomserver api might not be up yet, so we try again + res := rsapi.QueryLeftUsersResponse{} + for i := 0; i <= maxRetries; i++ { + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{UserIDs: staleUsers}, &res); err != nil { + if i == maxRetries { + return err + } + logrus.WithError(err).Warnf("unable to query left users (try %d/%d)", i+1, maxRetries) + time.Sleep(time.Second * 3) + } + } + if len(res.UserIDs) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.UserIDs)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.UserIDs) +} + func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { u.mu.Lock() defer u.mu.Unlock() diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 28a13a0a0..d277c137e 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -53,6 +53,10 @@ type mockDeviceListUpdaterDatabase struct { mu sync.Mutex // protect staleUsers } +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { @@ -147,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil) event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -219,7 +223,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -288,7 +292,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 9ae4f9ca3..8cc31f77a 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -16,6 +16,7 @@ package keyserver import ( "github.com/gorilla/mux" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/sirupsen/logrus" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" @@ -40,6 +41,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, + rsAPI rsapi.KeyserverRoomserverAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) @@ -47,6 +49,7 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } + keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), JetStream: js, @@ -58,7 +61,7 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 242e16a06..c6a8f44cd 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -85,4 +85,9 @@ type Database interface { StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error } diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index 63281adfb..abd523f23 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -19,6 +19,10 @@ import ( "database/sql" "time" + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt } func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5beeed0f1..54dd6ddc9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget( return nil }) } + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *Database) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index fc2cc37c4..54702c885 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -17,8 +17,11 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + type staleDeviceListsStatements struct { db *sql.DB upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime } func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 37a010a7c..24da1125e 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -56,6 +56,7 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go new file mode 100644 index 000000000..76d3baddd --- /dev/null +++ b/keyserver/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/keyserver/storage/postgres" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a1373a62b..5850f0de3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,6 +17,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -198,3 +199,7 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error } + +type KeyserverRoomserverAPI interface { + QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 342a3904c..81b0dc2dc 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -19,6 +19,10 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } +func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error { + return t.Impl.QueryLeftUsers(ctx, req, res) +} + func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b62907f3c..51008825b 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -447,3 +447,11 @@ type QueryMembershipAtEventResponse struct { // do not have known state will return an empty array here. Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` } + +type QueryLeftUsersRequest struct { + UserIDs []string `json:"user_ids"` +} + +type QueryLeftUsersResponse struct { + UserIDs []string `json:"user_ids"` +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 8850e5c46..5c5a62d8c 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS return nil } +func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error { + var err error + res.UserIDs, err = r.DB.GetLeftUsers(ctx, req.UserIDs) + return err +} + func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") if err != nil { diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1bd1b3fb7..8a2e0a03c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -63,6 +63,7 @@ const ( RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" + RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers" ) type httpRoomserverInternalAPI struct { @@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, h.httpClient, ctx, request, response, ) } + +func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error { + return httputil.CallInternalRPCAPI( + "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d37e90b5..3cb05548a 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { RoomserverQueryMembershipAtEventPath, httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), ) + + internalAPIMux.Handle( + RoomserverQueryLeftMembersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", r.QueryLeftUsers), + ) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c39a8cbba..740be8dc6 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,4 +172,6 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + + GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0150534e1..ce8888b2f 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -157,6 +157,12 @@ const selectServerInRoomSQL = "" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid = ANY($2) +` + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -174,6 +180,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func CreateMembershipTable(db *sql.DB) error { @@ -209,9 +216,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, }.Prepare(db) } +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt) + rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs)) + if err != nil { + return nil, err + } + + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4455ec3bf..0a46514af 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1343,6 +1343,36 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ return result, nil } +func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) { + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs) + if err != nil { + return nil, err + } + + userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)) + userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap)) + for userID, nid := range stateKeyNIDMap { + userNIDs = append(userNIDs, nid) + userNIDtoUserID[nid] = userID + } + + stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs) + if err != nil { + return nil, err + } + + for _, joinedUser := range stillJoinedUsersNIDs { + delete(userNIDtoUserID, joinedUser) + } + + leftUsers := make([]string, 0, len(userNIDtoUserID)) + for _, userID := range userNIDtoUserID { + leftUsers = append(leftUsers, userID) + } + + return leftUsers, nil +} + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index cd149f0ed..66ed3f45f 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -133,6 +133,12 @@ const selectServerInRoomSQL = "" + const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid IN ($2) +` + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -149,6 +155,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime } func CreateMembershipTable(db *sql.DB) error { @@ -412,3 +419,40 @@ func (s *membershipStatements) DeleteMembership( ) return err } + +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1) + + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed") + + params := make([]any, len(targetUserNIDs)+1) + params[0] = tables.MembershipStateLeaveOrBan + for i := range targetUserNIDs { + params[i+1] = targetUserNIDs[i] + } + + stmt = sqlutil.TxStmt(txn, stmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 50d27c756..80fcf72dd 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -144,6 +144,7 @@ type Membership interface { SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error + SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } type Published interface { diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c9541d9d2..c4524ee44 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) { knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) assert.NoError(t, err) assert.Equal(t, 1, len(knownUsers)) + + // get users we share a room with, given their userNID + joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs) + assert.NoError(t, err) + // Only userNIDs[0] is actually joined, so we only expect this userNID + assert.Equal(t, userNIDs[:1], joinedUsers) }) }