From c0af970cc6406b26452eb16856d188b0b5406361 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 25 Jan 2022 16:37:00 +0000 Subject: [PATCH] Make excluding self behaviour optional --- federationapi/api/api.go | 3 ++- federationapi/consumers/keychange.go | 4 ++-- federationapi/internal/query.go | 2 +- federationapi/storage/interface.go | 2 +- federationapi/storage/shared/storage.go | 10 ++++++---- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 1e535f406..f5ee75b4b 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -188,7 +188,8 @@ type PerformServersAliveResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 8231fcf44..6a737d0ad 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") return nil @@ -180,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") return nil diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index ce57778be..b0a76eeb7 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) if err != nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index a36f51528..21a919f6a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 944c0f67f..160c7f6fa 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -103,14 +103,16 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) if err != nil { return nil, err } - for i, server := range servers { - if server == d.ServerName { - servers = append(servers[:i], servers[i+1:]...) + if excludeSelf { + for i, server := range servers { + if server == d.ServerName { + servers = append(servers[:i], servers[i+1:]...) + } } } return servers, nil