Use sync API database in filterSharedUsers
(#2572)
* Add function to the sync API storage package for filtering shared users * Use the database instead of asking the RS API * Fix unit tests * Fix map handling in `filterSharedUsers`
This commit is contained in:
parent
69c86295f7
commit
90bf01d8b1
|
@ -21,6 +21,7 @@ import (
|
||||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
keytypes "github.com/matrix-org/dendrite/keyserver/types"
|
keytypes "github.com/matrix-org/dendrite/keyserver/types"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -46,7 +47,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
func DeviceListCatchup(
|
func DeviceListCatchup(
|
||||||
ctx context.Context, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
|
ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
|
||||||
userID string, res *types.Response, from, to types.StreamPosition,
|
userID string, res *types.Response, from, to types.StreamPosition,
|
||||||
) (newPos types.StreamPosition, hasNew bool, err error) {
|
) (newPos types.StreamPosition, hasNew bool, err error) {
|
||||||
|
|
||||||
|
@ -93,7 +94,7 @@ func DeviceListCatchup(
|
||||||
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
||||||
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
||||||
var sharedUsersMap map[string]int
|
var sharedUsersMap map[string]int
|
||||||
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
|
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
|
||||||
util.GetLogger(ctx).Debugf(
|
util.GetLogger(ctx).Debugf(
|
||||||
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
|
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
|
||||||
offset, toOffset, queryRes.Offset, queryRes.UserIDs,
|
offset, toOffset, queryRes.Offset, queryRes.UserIDs,
|
||||||
|
@ -215,30 +216,28 @@ func TrackChangedUsers(
|
||||||
return changed, left, nil
|
return changed, left, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterSharedUsers takes a list of remote users whose keys have changed and filters
|
||||||
|
// it down to include only users who the requesting user shares a room with.
|
||||||
func filterSharedUsers(
|
func filterSharedUsers(
|
||||||
ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string,
|
ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
|
||||||
) (map[string]int, []string) {
|
) (map[string]int, []string) {
|
||||||
var result []string
|
sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
|
||||||
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
|
for _, userID := range usersWithChangedKeys {
|
||||||
err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
|
sharedUsersMap[userID] = 0
|
||||||
UserID: userID,
|
}
|
||||||
OtherUserIDs: usersWithChangedKeys,
|
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
|
||||||
}, &sharedUsersRes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// default to all users so we do needless queries rather than miss some important device update
|
// default to all users so we do needless queries rather than miss some important device update
|
||||||
return nil, usersWithChangedKeys
|
return nil, usersWithChangedKeys
|
||||||
}
|
}
|
||||||
|
for _, userID := range sharedUsers {
|
||||||
|
sharedUsersMap[userID]++
|
||||||
|
}
|
||||||
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
||||||
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
||||||
// be notified about key changes.
|
// be notified about key changes.
|
||||||
sharedUsersRes.UserIDsToCount[userID] = 1
|
sharedUsersMap[userID] = 1
|
||||||
|
return sharedUsersMap, sharedUsers
|
||||||
for _, uid := range usersWithChangedKeys {
|
|
||||||
if sharedUsersRes.UserIDsToCount[uid] > 0 {
|
|
||||||
result = append(result, uid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sharedUsersRes.UserIDsToCount, result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinedRooms(res *types.Response, userID string) []string {
|
func joinedRooms(res *types.Response, userID string) []string {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -105,6 +106,22 @@ func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.Query
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is actually a database function, but seeing as we track the state inside the
|
||||||
|
// *mockRoomserverAPI, we'll just comply with the interface here instead.
|
||||||
|
func (s *mockRoomserverAPI) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
|
||||||
|
commonUsers := []string{}
|
||||||
|
for _, members := range s.roomIDToJoinedMembers {
|
||||||
|
for _, member := range members {
|
||||||
|
for _, userID := range otherUserIDs {
|
||||||
|
if member == userID {
|
||||||
|
commonUsers = append(commonUsers, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.UniqueStrings(commonUsers), nil
|
||||||
|
}
|
||||||
|
|
||||||
type wantCatchup struct {
|
type wantCatchup struct {
|
||||||
hasNew bool
|
hasNew bool
|
||||||
changed []string
|
changed []string
|
||||||
|
@ -178,7 +195,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -201,7 +218,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -224,7 +241,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
|
||||||
"!another:room": {syncingUser, existingUser},
|
"!another:room": {syncingUser, existingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("Catchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -246,7 +263,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
|
||||||
"!another:room": {syncingUser, existingUser},
|
"!another:room": {syncingUser, existingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -305,7 +322,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
|
||||||
roomID: {syncingUser, existingUser},
|
roomID: {syncingUser, existingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -333,7 +350,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
_, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("Catchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -419,7 +436,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, hasNew, err := DeviceListCatchup(
|
_, hasNew, err := DeviceListCatchup(
|
||||||
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
|
context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
|
|
|
@ -27,6 +27,8 @@ import (
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
Presence
|
Presence
|
||||||
|
SharedUsers
|
||||||
|
|
||||||
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
||||||
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
|
||||||
|
@ -165,3 +167,8 @@ type Presence interface {
|
||||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||||
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
|
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SharedUsers interface {
|
||||||
|
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||||
|
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
|
||||||
|
}
|
||||||
|
|
|
@ -107,6 +107,11 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
|
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
|
||||||
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
|
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
|
||||||
|
|
||||||
|
const selectSharedUsersSQL = "" +
|
||||||
|
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
|
||||||
|
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
|
") AND state_key = ANY($2) AND membership='join';"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
|
@ -118,6 +123,7 @@ type currentRoomStateStatements struct {
|
||||||
selectJoinedUsersInRoomStmt *sql.Stmt
|
selectJoinedUsersInRoomStmt *sql.Stmt
|
||||||
selectEventsWithEventIDsStmt *sql.Stmt
|
selectEventsWithEventIDsStmt *sql.Stmt
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
|
selectSharedUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
|
@ -156,6 +162,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,3 +388,24 @@ func (s *currentRoomStateStatements) SelectStateEvent(
|
||||||
}
|
}
|
||||||
return &ev, err
|
return &ev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
||||||
|
) ([]string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
|
||||||
|
|
||||||
|
var stateKey string
|
||||||
|
result := make([]string, 0, len(otherUserIDs))
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&stateKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, stateKey)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -176,6 +176,10 @@ func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]t
|
||||||
return d.Peeks.SelectPeekingDevices(ctx)
|
return d.Peeks.SelectPeekingDevices(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
|
||||||
|
return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) GetStateEvent(
|
func (d *Database) GetStateEvent(
|
||||||
ctx context.Context, roomID, evType, stateKey string,
|
ctx context.Context, roomID, evType, stateKey string,
|
||||||
) (*gomatrixserverlib.HeaderedEvent, error) {
|
) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
|
|
@ -91,6 +91,11 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
|
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
|
||||||
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
||||||
|
|
||||||
|
const selectSharedUsersSQL = "" +
|
||||||
|
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
|
||||||
|
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
|
") AND state_key IN ($2) AND membership='join';"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *StreamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
|
@ -100,8 +105,9 @@ type currentRoomStateStatements struct {
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
selectRoomIDsWithAnyMembershipStmt *sql.Stmt
|
selectRoomIDsWithAnyMembershipStmt *sql.Stmt
|
||||||
selectJoinedUsersStmt *sql.Stmt
|
selectJoinedUsersStmt *sql.Stmt
|
||||||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
|
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
|
@ -396,3 +402,29 @@ func (s *currentRoomStateStatements) SelectStateEvent(
|
||||||
}
|
}
|
||||||
return &ev, err
|
return &ev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
||||||
|
) ([]string, error) {
|
||||||
|
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
|
||||||
|
stmt, err := s.db.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
|
||||||
|
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
|
||||||
|
|
||||||
|
var stateKey string
|
||||||
|
result := make([]string, 0, len(otherUserIDs))
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&stateKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, stateKey)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -104,6 +104,8 @@ type CurrentRoomState interface {
|
||||||
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
|
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
|
||||||
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
|
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
|
||||||
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
|
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
|
||||||
|
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||||
|
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||||
|
|
|
@ -28,7 +28,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
||||||
from, to types.StreamPosition,
|
from, to types.StreamPosition,
|
||||||
) types.StreamPosition {
|
) types.StreamPosition {
|
||||||
var err error
|
var err error
|
||||||
to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
|
to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
||||||
return from
|
return from
|
||||||
|
|
|
@ -429,7 +429,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
}
|
}
|
||||||
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
|
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
|
||||||
_, _, err = internal.DeviceListCatchup(
|
_, _, err = internal.DeviceListCatchup(
|
||||||
req.Context(), rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
|
req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
|
||||||
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
|
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue