Limit JoinedUsersSetInRooms to interested users (#2234)

* Limit database work in `JoinedUsersSetInRooms` to changed user IDs only

* Comments

* Fix variadic params for SQLite, update comments
This commit is contained in:
Neil Alexander 2022-03-01 13:01:38 +00:00 committed by GitHub
parent 58bf91a585
commit 530f05885d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 49 additions and 31 deletions

View file

@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
UserID string UserID string
OtherUserIDs []string
ExcludeRoomIDs []string ExcludeRoomIDs []string
IncludeRoomIDs []string IncludeRoomIDs []string
} }

View file

@ -696,7 +696,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
} }
roomIDs = roomIDs[:j] roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
if err != nil { if err != nil {
return err return err
} }

View file

@ -151,8 +151,8 @@ type Database interface {
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms( func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID, roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) { ) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil return result, nil
} }
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
for id, nid := range userNIDsMap {
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid stateKeyNIDs[i] = nid
i++ i++
} }
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) { if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID)) logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
} }

View file

@ -42,7 +42,8 @@ const membershipSchema = `
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
for i, v := range roomNIDs { for _, v := range roomNIDs {
iRoomNIDs[i] = v params = append(params, v)
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if txn != nil { if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) rows, err = txn.QueryContext(ctx, query, params...)
} else { } else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) rows, err = s.db.QueryContext(ctx, query, params...)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -127,9 +127,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
// counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

View file

@ -82,7 +82,16 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return to, hasNew, nil return to, hasNew, nil
} }
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
// Work out which user IDs we care about — that includes those in the original request,
// the response from QueryKeyChanges (which includes ALL users who have changed keys)
// as well as every user who has a join or leave event in the current sync response. We
// will request information about which rooms these users are joined to, so that we can
// see if we still share any rooms with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf( util.GetLogger(ctx).Debugf(
@ -100,9 +109,8 @@ func DeviceListCatchup(
userSet[userID] = true userSet[userID] = true
} }
} }
// if the response has any join/leave events, add them now. // Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
joinUserIDs, leaveUserIDs := membershipEvents(res)
for _, userID := range joinUserIDs { for _, userID := range joinUserIDs {
if !userSet[userID] { if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
@ -214,6 +222,7 @@ func filterSharedUsers(
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: userID, UserID: userID,
OtherUserIDs: usersWithChangedKeys,
}, &sharedUsersRes) }, &sharedUsersRes)
if err != nil { if err != nil {
// default to all users so we do needless queries rather than miss some important device update // default to all users so we do needless queries rather than miss some important device update