diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index c146b2aa0..91f011517 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "sync" "time" @@ -314,6 +315,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for targetKeyID := range masterKey.Keys { sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue } @@ -335,6 +341,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for targetKeyID, key := range forUserID { sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 03df9285c..4bf54cae0 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -25,10 +25,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) -const DeviceListLogName = "dl" - // DeviceOTKCounts adds one-time key counts to the /sync response func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { var queryRes keyapi.QueryOneTimeKeysResponse @@ -93,18 +92,13 @@ func DeviceListCatchup( queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) - var sharedUsersMap map[string]int - sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs) - util.GetLogger(ctx).Debugf( - "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v", - offset, toOffset, queryRes.Offset, queryRes.UserIDs, - ) + sharedUsersMap := filterSharedUsers(ctx, db, userID, queryRes.UserIDs) userSet := make(map[string]bool) for _, userID := range res.DeviceLists.Changed { userSet[userID] = true } - for _, userID := range queryRes.UserIDs { - if !userSet[userID] { + for userID, count := range sharedUsersMap { + if !userSet[userID] && count > 0 { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) hasNew = true userSet[userID] = true @@ -113,7 +107,7 @@ func DeviceListCatchup( // Finally, add in users who have joined or left. // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. for _, userID := range joinUserIDs { - if !userSet[userID] { + if !userSet[userID] && sharedUsersMap[userID] > 0 { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) hasNew = true userSet[userID] = true @@ -126,6 +120,13 @@ func DeviceListCatchup( } } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "user_id": userID, + "from": offset, + "to": toOffset, + "response_offset": queryRes.Offset, + }).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists) + return types.StreamPosition(queryRes.Offset), hasNew, nil } @@ -220,24 +221,27 @@ func TrackChangedUsers( // it down to include only users who the requesting user shares a room with. func filterSharedUsers( ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string, -) (map[string]int, []string) { +) map[string]int { sharedUsersMap := make(map[string]int, len(usersWithChangedKeys)) - for _, userID := range usersWithChangedKeys { - sharedUsersMap[userID] = 0 + for _, changedUserID := range usersWithChangedKeys { + sharedUsersMap[changedUserID] = 0 + if changedUserID == userID { + // We forcibly put ourselves in this list because we should be notified about our own device updates + // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't + // be notified about key changes. + sharedUsersMap[userID] = 1 + } } sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys) if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err) // default to all users so we do needless queries rather than miss some important device update - return nil, usersWithChangedKeys + return sharedUsersMap } for _, userID := range sharedUsers { sharedUsersMap[userID]++ } - // We forcibly put ourselves in this list because we should be notified about our own device updates - // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't - // be notified about key changes. - sharedUsersMap[userID] = 1 - return sharedUsersMap, sharedUsers + return sharedUsersMap } func joinedRooms(res *types.Response, userID string) []string { diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 79ed440e7..6bfc91edd 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -129,6 +129,7 @@ type wantCatchup struct { } func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) { + t.Helper() if hasNew != want.hasNew { t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew) } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index d13b7be41..58f404511 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -112,7 +112,7 @@ const selectEventsWithEventIDsSQL = "" + const selectSharedUsersSQL = "" + "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" + " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + - ") AND state_key = ANY($2) AND membership='join';" + ") AND state_key = ANY($2) AND membership IN ('join', 'invite');" type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt @@ -407,7 +407,7 @@ func (s *currentRoomStateStatements) SelectSharedUsers( ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt) - rows, err := stmt.QueryContext(ctx, userID, otherUserIDs) + rows, err := stmt.QueryContext(ctx, userID, pq.Array(otherUserIDs)) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index e19298aee..3a10b2325 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -94,9 +94,9 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" const selectSharedUsersSQL = "" + - "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" + + "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" + " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + - ") AND state_key IN ($2) AND membership='join';" + ") AND state_key IN ($2) AND membership IN ('join', 'invite');" type currentRoomStateStatements struct { db *sql.DB @@ -420,25 +420,28 @@ func (s *currentRoomStateStatements) SelectStateEvent( func (s *currentRoomStateStatements) SelectSharedUsers( ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string, ) ([]string, error) { - query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1) - stmt, err := s.db.Prepare(query) - if err != nil { - return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed") - rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed") - var stateKey string - result := make([]string, 0, len(otherUserIDs)) - for rows.Next() { - if err := rows.Scan(&stateKey); err != nil { - return nil, err - } - result = append(result, stateKey) + params := make([]interface{}, len(otherUserIDs)+1) + params[0] = userID + for k, v := range otherUserIDs { + params[k+1] = v } - return result, rows.Err() + + result := make([]string, 0, len(otherUserIDs)) + query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1) + err := sqlutil.RunLimitedVariablesQuery( + ctx, query, s.db, params, sqlutil.SQLite3MaxVariables, + func(rows *sql.Rows) error { + var stateKey string + for rows.Next() { + if err := rows.Scan(&stateKey); err != nil { + return err + } + result = append(result, stateKey) + } + return nil + }, + ) + + return result, err }