Make excluding self behaviour optional

This commit is contained in:
Neil Alexander 2022-01-25 16:37:00 +00:00
parent ca48f956d4
commit 04818c4aa4
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
5 changed files with 12 additions and 9 deletions

View file

@ -186,7 +186,8 @@ type PerformServersAliveResponse struct {
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames

View file

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
return nil return nil
@ -180,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
return nil return nil

View file

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf)
if err != nil { if err != nil {
return return
} }

View file

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -103,14 +103,16 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i, server := range servers { if excludeSelf {
if server == d.ServerName { for i, server := range servers {
servers = append(servers[:i], servers[i+1:]...) if server == d.ServerName {
servers = append(servers[:i], servers[i+1:]...)
}
} }
} }
return servers, nil return servers, nil