Fix unit tests

This commit is contained in:
Neil Alexander 2022-07-15 14:11:39 +01:00
parent 1471091d98
commit 2531ee7cb1
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 36 additions and 12 deletions

View file

@ -47,7 +47,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information. // be already filled in with join/leave information.
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, db storage.Database, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition, userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) { ) (newPos types.StreamPosition, hasNew bool, err error) {
@ -219,9 +219,11 @@ func TrackChangedUsers(
// filterSharedUsers takes a list of remote users whose keys have changed and filters // filterSharedUsers takes a list of remote users whose keys have changed and filters
// it down to include only users who the requesting user shares a room with. // it down to include only users who the requesting user shares a room with.
func filterSharedUsers( func filterSharedUsers(
ctx context.Context, db storage.Database, userID string, usersWithChangedKeys []string, ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
) (map[string]int, []string) { ) (map[string]int, []string) {
var sharedUsersRes roomserverAPI.QuerySharedUsersResponse sharedUsersRes := &roomserverAPI.QuerySharedUsersResponse{
UserIDsToCount: map[string]int{},
}
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys) sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
if err != nil { if err != nil {
// default to all users so we do needless queries rather than miss some important device update // default to all users so we do needless queries rather than miss some important device update

View file

@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
var ( var (
@ -105,6 +106,22 @@ func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.Query
return nil return nil
} }
// This is actually a database function, but seeing as we track the state inside the
// *mockRoomserverAPI, we'll just comply with the interface here instead.
func (s *mockRoomserverAPI) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
commonUsers := []string{}
for _, members := range s.roomIDToJoinedMembers {
for _, member := range members {
for _, userID := range otherUserIDs {
if member == userID {
commonUsers = append(commonUsers, userID)
}
}
}
}
return util.UniqueStrings(commonUsers), nil
}
type wantCatchup struct { type wantCatchup struct {
hasNew bool hasNew bool
changed []string changed []string
@ -178,7 +195,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -201,7 +218,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -224,7 +241,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -246,7 +263,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -305,7 +322,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser}, roomID: {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -333,7 +350,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -419,7 +436,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
}, },
} }
_, hasNew, err := DeviceListCatchup( _, hasNew, err := DeviceListCatchup(
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken, context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
) )
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)

View file

@ -27,6 +27,8 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
@ -54,8 +56,6 @@ type Database interface {
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
@ -167,3 +167,8 @@ type Presence interface {
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
} }
type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
}