Cleanup device list entries for users we don't share a room with anymore

This commit is contained in:
Till Faelligen 2022-11-04 11:17:33 +01:00
parent 98d3f88bfb
commit 26f625e015
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
21 changed files with 370 additions and 25 deletions

View file

@ -24,6 +24,8 @@ import (
"sync" "sync"
"time" "time"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -102,6 +104,7 @@ type DeviceListUpdater struct {
// block on or timeout via a select. // block on or timeout via a select.
userIDToChan map[string]chan bool userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex userIDToChanMu *sync.Mutex
rsAPI rsapi.KeyserverRoomserverAPI
} }
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
} }
type DeviceListUpdaterAPI interface { type DeviceListUpdaterAPI interface {
@ -140,6 +145,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
rsAPI rsapi.KeyserverRoomserverAPI,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process, process: process,
@ -152,11 +158,16 @@ func NewDeviceListUpdater(
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool), userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{}, userIDToChanMu: &sync.Mutex{},
rsAPI: rsAPI,
} }
} }
// Start the device list updater, which will try to refresh any stale device lists. // Start the device list updater, which will try to refresh any stale device lists.
func (u *DeviceListUpdater) Start() error { func (u *DeviceListUpdater) Start() error {
if err := u.cleanUp(); err != nil {
return fmt.Errorf("failed to cleanup stale device lists: %w", err)
}
for i := 0; i < len(u.workerChans); i++ { for i := 0; i < len(u.workerChans); i++ {
// Allocate a small buffer per channel. // Allocate a small buffer per channel.
// If the buffer limit is reached, backpressure will cause the processing of EDUs // If the buffer limit is reached, backpressure will cause the processing of EDUs
@ -166,7 +177,7 @@ func (u *DeviceListUpdater) Start() error {
go u.worker(ch) go u.worker(ch)
} }
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil { if err != nil {
return err return err
} }
@ -184,6 +195,36 @@ func (u *DeviceListUpdater) Start() error {
return nil return nil
} }
// cleanUp removes stale device entries for users we don't share a room with anymore
func (u *DeviceListUpdater) cleanUp() error {
if u.rsAPI == nil {
return nil
}
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
maxRetries := 3
// In polylith mode, the roomserver api might not be up yet, so we try again
res := rsapi.QueryLeftUsersResponse{}
for i := 0; i <= maxRetries; i++ {
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{UserIDs: staleUsers}, &res); err != nil {
if i == maxRetries {
return err
}
logrus.WithError(err).Warnf("unable to query left users (try %d/%d)", i+1, maxRetries)
time.Sleep(time.Second * 3)
}
}
if len(res.UserIDs) == 0 {
return nil
}
logrus.Debugf("Deleting %d stale device list entries", len(res.UserIDs))
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.UserIDs)
}
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
u.mu.Lock() u.mu.Lock()
defer u.mu.Unlock() defer u.mu.Unlock()

View file

@ -53,6 +53,10 @@ type mockDeviceListUpdaterDatabase struct {
mu sync.Mutex // protect staleUsers mu sync.Mutex // protect staleUsers
} }
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
return nil
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned. // If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
@ -147,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil)
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -219,7 +223,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -288,7 +292,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -16,6 +16,7 @@ package keyserver
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
@ -40,6 +41,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
rsAPI rsapi.KeyserverRoomserverAPI,
) api.KeyInternalAPI { ) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
@ -47,6 +49,7 @@ func NewInternalAPI(
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database") logrus.WithError(err).Panicf("failed to connect to key server database")
} }
keyChangeProducer := &producers.KeyChange{ keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
JetStream: js, JetStream: js,
@ -58,7 +61,7 @@ func NewInternalAPI(
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {

View file

@ -85,4 +85,9 @@ type Database interface {
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
} }

View file

@ -19,6 +19,10 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt
deleteStaleDeviceListsStmt *sql.Stmt
} }
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { return s, sqlutil.StatementList{
return nil, err {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
} {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
return nil, err {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
} }.Prepare(db)
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
} }
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil return result, nil
} }
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() { for rows.Next() {

View file

@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
return nil return nil
}) })
} }
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *Database) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View file

@ -17,8 +17,11 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
type staleDeviceListsStatements struct { type staleDeviceListsStatements struct {
db *sql.DB db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
} }
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { return s, sqlutil.StatementList{
return nil, err {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
} {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
return nil, err // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
} }.Prepare(db)
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
} }
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil return result, nil
} }
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
stmt = sqlutil.TxStmt(txn, stmt)
params := make([]any, len(userIDs))
for i := range userIDs {
params[i] = userIDs[i]
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() { for rows.Next() {

View file

@ -56,6 +56,7 @@ type KeyChanges interface {
type StaleDeviceLists interface { type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
} }
type CrossSigningKeys interface { type CrossSigningKeys interface {

View file

@ -0,0 +1,94 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/test"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, nil)
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
}
if err != nil {
t.Fatalf("failed to create new table: %s", err)
}
return tab, close
}
func TestStaleDeviceLists(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := "@charlie:localhost"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateTable(t, dbType)
defer closeDB()
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
// Query one server
wantStaleUsers := []string{alice.ID, bob.ID}
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Query all servers
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Delete stale devices
deleteUsers := []string{alice.ID, bob.ID}
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
t.Fatalf("failed to delete stale device lists: %s", err)
}
// Verify we don't get anything back after deleting
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if gotCount := len(gotStaleUsers); gotCount > 0 {
t.Fatalf("expected no stale users, got %d", gotCount)
}
})
}

View file

@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI ClientRoomserverAPI
UserRoomserverAPI UserRoomserverAPI
FederationRoomserverAPI FederationRoomserverAPI
KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
@ -198,3 +199,7 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
} }
type KeyserverRoomserverAPI interface {
QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
}

View file

@ -19,6 +19,10 @@ type RoomserverInternalAPITrace struct {
Impl RoomserverInternalAPI Impl RoomserverInternalAPI
} }
func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
return t.Impl.QueryLeftUsers(ctx, req, res)
}
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
t.Impl.SetFederationAPI(fsAPI, keyRing) t.Impl.SetFederationAPI(fsAPI, keyRing)
} }

View file

@ -447,3 +447,11 @@ type QueryMembershipAtEventResponse struct {
// do not have known state will return an empty array here. // do not have known state will return an empty array here.
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
} }
type QueryLeftUsersRequest struct {
UserIDs []string `json:"user_ids"`
}
type QueryLeftUsersResponse struct {
UserIDs []string `json:"user_ids"`
}

View file

@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
return nil return nil
} }
func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
var err error
res.UserIDs, err = r.DB.GetLeftUsers(ctx, req.UserIDs)
return err
}
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil { if err != nil {

View file

@ -63,6 +63,7 @@ const (
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
) )
type httpRoomserverInternalAPI struct { type httpRoomserverInternalAPI struct {
@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }
func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
return httputil.CallInternalRPCAPI(
"RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
RoomserverQueryMembershipAtEventPath, RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
) )
internalAPIMux.Handle(
RoomserverQueryLeftMembersPath,
httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", r.QueryLeftUsers),
)
} }

View file

@ -172,4 +172,6 @@ type Database interface {
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
} }

View file

@ -157,6 +157,12 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid = ANY($2)
`
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
@ -174,6 +180,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt deleteMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
@ -209,9 +216,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL},
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
if err != nil {
return nil, err
}
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,

View file

@ -1343,6 +1343,36 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
return result, nil return result, nil
} }
func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
for userID, nid := range stateKeyNIDMap {
userNIDs = append(userNIDs, nid)
userNIDtoUserID[nid] = userID
}
stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
if err != nil {
return nil, err
}
for _, joinedUser := range stillJoinedUsersNIDs {
delete(userNIDtoUserID, joinedUser)
}
leftUsers := make([]string, 0, len(userNIDtoUserID))
for _, userID := range userNIDtoUserID {
leftUsers = append(leftUsers, userID)
}
return leftUsers, nil
}
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)

View file

@ -133,6 +133,12 @@ const selectServerInRoomSQL = "" +
const deleteMembershipSQL = "" + const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid IN ($2)
`
type membershipStatements struct { type membershipStatements struct {
db *sql.DB db *sql.DB
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
@ -149,6 +155,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt deleteMembershipStmt *sql.Stmt
// selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
@ -412,3 +419,40 @@ func (s *membershipStatements) DeleteMembership(
) )
return err return err
} }
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
params := make([]any, len(targetUserNIDs)+1)
params[0] = tables.MembershipStateLeaveOrBan
for i := range targetUserNIDs {
params[i+1] = targetUserNIDs[i]
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}

View file

@ -144,6 +144,7 @@ type Membership interface {
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
} }
type Published interface { type Published interface {

View file

@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(knownUsers)) assert.Equal(t, 1, len(knownUsers))
// get users we share a room with, given their userNID
joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
assert.NoError(t, err)
// Only userNIDs[0] is actually joined, so we only expect this userNID
assert.Equal(t, userNIDs[:1], joinedUsers)
}) })
} }