Change query to include invited users

This commit is contained in:
Till Faelligen 2022-08-03 18:12:42 +02:00
parent 401187efae
commit 7f4aeb7091
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
4 changed files with 18 additions and 20 deletions

View file

@ -28,8 +28,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
const DeviceListLogName = "dl"
// DeviceOTKCounts adds one-time key counts to the /sync response // DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse var queryRes keyapi.QueryOneTimeKeysResponse
@ -94,16 +92,13 @@ func DeviceListCatchup(
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...) queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...) queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs) queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int sharedUsersMap := filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
userSet := make(map[string]bool) userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed { for _, userID := range res.DeviceLists.Changed {
if sharedUsersMap[userID] > 0 { userSet[userID] = true
userSet[userID] = true
}
} }
for _, userID := range queryRes.UserIDs { for userID, count := range sharedUsersMap {
if !userSet[userID] && sharedUsersMap[userID] > 0 { if !userSet[userID] && count > 0 {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true hasNew = true
userSet[userID] = true userSet[userID] = true
@ -226,25 +221,27 @@ func TrackChangedUsers(
// it down to include only users who the requesting user shares a room with. // it down to include only users who the requesting user shares a room with.
func filterSharedUsers( func filterSharedUsers(
ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string, ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
) (map[string]int, []string) { ) map[string]int {
sharedUsersMap := make(map[string]int, len(usersWithChangedKeys)) sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
for _, userID := range usersWithChangedKeys { for _, changedUserID := range usersWithChangedKeys {
sharedUsersMap[userID] = 0 sharedUsersMap[changedUserID] = 0
if changedUserID == userID {
// We forcibly put ourselves in this list because we should be notified about our own device updates
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
// be notified about key changes.
sharedUsersMap[userID] = 1
}
} }
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys) sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err) util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err)
// default to all users so we do needless queries rather than miss some important device update // default to all users so we do needless queries rather than miss some important device update
return nil, usersWithChangedKeys return sharedUsersMap
} }
for _, userID := range sharedUsers { for _, userID := range sharedUsers {
sharedUsersMap[userID]++ sharedUsersMap[userID]++
} }
// We forcibly put ourselves in this list because we should be notified about our own device updates return sharedUsersMap
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
// be notified about key changes.
sharedUsersMap[userID] = 1
return sharedUsersMap, sharedUsers
} }
func joinedRooms(res *types.Response, userID string) []string { func joinedRooms(res *types.Response, userID string) []string {

View file

@ -129,6 +129,7 @@ type wantCatchup struct {
} }
func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) { func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) {
t.Helper()
if hasNew != want.hasNew { if hasNew != want.hasNew {
t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew) t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew)
} }

View file

@ -112,7 +112,7 @@ const selectEventsWithEventIDsSQL = "" +
const selectSharedUsersSQL = "" + const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" + "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key = ANY($2) AND membership='join';" ") AND state_key = ANY($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt

View file

@ -96,7 +96,7 @@ const selectEventsWithEventIDsSQL = "" +
const selectSharedUsersSQL = "" + const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" + "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key IN ($2) AND membership='join';" ") AND state_key IN ($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB