Cleanup device list entries for users we don't share a room with anymore

This commit is contained in:
Till Faelligen 2022-11-04 11:17:33 +01:00
parent 98d3f88bfb
commit 26f625e015
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
21 changed files with 370 additions and 25 deletions

View file

@ -24,6 +24,8 @@ import (
"sync"
"time"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -102,6 +104,7 @@ type DeviceListUpdater struct {
// block on or timeout via a select.
userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex
rsAPI rsapi.KeyserverRoomserverAPI
}
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
}
type DeviceListUpdaterAPI interface {
@ -140,6 +145,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
rsAPI rsapi.KeyserverRoomserverAPI,
) *DeviceListUpdater {
return &DeviceListUpdater{
process: process,
@ -152,11 +158,16 @@ func NewDeviceListUpdater(
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
rsAPI: rsAPI,
}
}
// Start the device list updater, which will try to refresh any stale device lists.
func (u *DeviceListUpdater) Start() error {
if err := u.cleanUp(); err != nil {
return fmt.Errorf("failed to cleanup stale device lists: %w", err)
}
for i := 0; i < len(u.workerChans); i++ {
// Allocate a small buffer per channel.
// If the buffer limit is reached, backpressure will cause the processing of EDUs
@ -166,7 +177,7 @@ func (u *DeviceListUpdater) Start() error {
go u.worker(ch)
}
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
@ -184,6 +195,36 @@ func (u *DeviceListUpdater) Start() error {
return nil
}
// cleanUp removes stale device entries for users we don't share a room with anymore
func (u *DeviceListUpdater) cleanUp() error {
if u.rsAPI == nil {
return nil
}
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
maxRetries := 3
// In polylith mode, the roomserver api might not be up yet, so we try again
res := rsapi.QueryLeftUsersResponse{}
for i := 0; i <= maxRetries; i++ {
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{UserIDs: staleUsers}, &res); err != nil {
if i == maxRetries {
return err
}
logrus.WithError(err).Warnf("unable to query left users (try %d/%d)", i+1, maxRetries)
time.Sleep(time.Second * 3)
}
}
if len(res.UserIDs) == 0 {
return nil
}
logrus.Debugf("Deleting %d stale device list entries", len(res.UserIDs))
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.UserIDs)
}
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
u.mu.Lock()
defer u.mu.Unlock()

View file

@ -53,6 +53,10 @@ type mockDeviceListUpdaterDatabase struct {
mu sync.Mutex // protect staleUsers
}
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
return nil
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
@ -147,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil)
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@ -219,7 +223,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@ -288,7 +292,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq)
return <-fedCh, nil
})
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1)
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}

View file

@ -16,6 +16,7 @@ package keyserver
import (
"github.com/gorilla/mux"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/sirupsen/logrus"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
@ -40,6 +41,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
rsAPI rsapi.KeyserverRoomserverAPI,
) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
@ -47,6 +49,7 @@ func NewInternalAPI(
if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database")
}
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
JetStream: js,
@ -58,7 +61,7 @@ func NewInternalAPI(
FedClient: fedClient,
Producer: keyChangeProducer,
}
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI) // 8 workers TODO: configurable
ap.Updater = updater
go func() {
if err := updater.Start(); err != nil {

View file

@ -85,4 +85,9 @@ type Database interface {
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
}

View file

@ -19,6 +19,10 @@ import (
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
deleteStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {

View file

@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
return nil
})
}
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *Database) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View file

@ -17,8 +17,11 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
type staleDeviceListsStatements struct {
db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
stmt = sqlutil.TxStmt(txn, stmt)
params := make([]any, len(userIDs))
for i := range userIDs {
params[i] = userIDs[i]
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {

View file

@ -56,6 +56,7 @@ type KeyChanges interface {
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {

View file

@ -0,0 +1,94 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/test"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, nil)
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
}
if err != nil {
t.Fatalf("failed to create new table: %s", err)
}
return tab, close
}
func TestStaleDeviceLists(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := "@charlie:localhost"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateTable(t, dbType)
defer closeDB()
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
// Query one server
wantStaleUsers := []string{alice.ID, bob.ID}
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Query all servers
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Delete stale devices
deleteUsers := []string{alice.ID, bob.ID}
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
t.Fatalf("failed to delete stale device lists: %s", err)
}
// Verify we don't get anything back after deleting
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if gotCount := len(gotStaleUsers); gotCount > 0 {
t.Fatalf("expected no stale users, got %d", gotCount)
}
})
}

View file

@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@ -198,3 +199,7 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
}
type KeyserverRoomserverAPI interface {
QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
}

View file

@ -19,6 +19,10 @@ type RoomserverInternalAPITrace struct {
Impl RoomserverInternalAPI
}
func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
return t.Impl.QueryLeftUsers(ctx, req, res)
}
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
t.Impl.SetFederationAPI(fsAPI, keyRing)
}

View file

@ -447,3 +447,11 @@ type QueryMembershipAtEventResponse struct {
// do not have known state will return an empty array here.
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}
type QueryLeftUsersRequest struct {
UserIDs []string `json:"user_ids"`
}
type QueryLeftUsersResponse struct {
UserIDs []string `json:"user_ids"`
}

View file

@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
return nil
}
func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
var err error
res.UserIDs, err = r.DB.GetLeftUsers(ctx, req.UserIDs)
return err
}
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil {

View file

@ -63,6 +63,7 @@ const (
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
)
type httpRoomserverInternalAPI struct {
@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
return httputil.CallInternalRPCAPI(
"RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
)
internalAPIMux.Handle(
RoomserverQueryLeftMembersPath,
httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", r.QueryLeftUsers),
)
}

View file

@ -172,4 +172,6 @@ type Database interface {
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
}

View file

@ -157,6 +157,12 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid = ANY($2)
`
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@ -174,6 +180,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
}
func CreateMembershipTable(db *sql.DB) error {
@ -209,9 +216,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
}.Prepare(db)
}
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
if err != nil {
return nil, err
}
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}
func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,

View file

@ -1343,6 +1343,36 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
return result, nil
}
func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
for userID, nid := range stateKeyNIDMap {
userNIDs = append(userNIDs, nid)
userNIDtoUserID[nid] = userID
}
stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
if err != nil {
return nil, err
}
for _, joinedUser := range stillJoinedUsersNIDs {
delete(userNIDtoUserID, joinedUser)
}
leftUsers := make([]string, 0, len(userNIDtoUserID))
for _, userID := range userNIDtoUserID {
leftUsers = append(leftUsers, userID)
}
return leftUsers, nil
}
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)

View file

@ -133,6 +133,12 @@ const selectServerInRoomSQL = "" +
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid IN ($2)
`
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@ -149,6 +155,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
// selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
}
func CreateMembershipTable(db *sql.DB) error {
@ -412,3 +419,40 @@ func (s *membershipStatements) DeleteMembership(
)
return err
}
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
params := make([]any, len(targetUserNIDs)+1)
params[0] = tables.MembershipStateLeaveOrBan
for i := range targetUserNIDs {
params[i+1] = targetUserNIDs[i]
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}

View file

@ -144,6 +144,7 @@ type Membership interface {
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
}
type Published interface {

View file

@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
assert.NoError(t, err)
assert.Equal(t, 1, len(knownUsers))
// get users we share a room with, given their userNID
joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
assert.NoError(t, err)
// Only userNIDs[0] is actually joined, so we only expect this userNID
assert.Equal(t, userNIDs[:1], joinedUsers)
})
}