From d14efa29cbd44cae459409e6970d5b91c8c90896 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 14 Nov 2022 12:30:20 +0000 Subject: [PATCH] Federation fixes for virtual hosting (maybe) --- clientapi/routing/createroom.go | 2 +- clientapi/routing/directory.go | 2 +- clientapi/routing/directory_public.go | 2 +- clientapi/routing/membership.go | 1 + clientapi/routing/profile.go | 4 +- clientapi/routing/redaction.go | 2 +- clientapi/routing/sendevent.go | 1 + clientapi/routing/server_notices.go | 1 + clientapi/routing/userdirectory.go | 2 +- clientapi/threepid/invites.go | 1 + cmd/dendrite-demo-pinecone/conn/client.go | 4 +- cmd/dendrite-demo-pinecone/rooms/rooms.go | 8 ++- cmd/dendrite-demo-yggdrasil/yggconn/client.go | 3 +- .../yggrooms/yggrooms.go | 9 +++- cmd/furl/main.go | 12 +++-- federationapi/api/api.go | 54 +++++++++---------- federationapi/federationapi_keys_test.go | 2 +- federationapi/federationapi_test.go | 8 +-- federationapi/internal/federationclient.go | 44 +++++++-------- federationapi/internal/perform.go | 34 +++++++++--- federationapi/inthttp/client.go | 45 +++++++++++----- federationapi/inthttp/server.go | 22 ++++---- federationapi/routing/backfill.go | 3 +- federationapi/routing/query.go | 2 +- federationapi/routing/send.go | 13 ++--- federationapi/routing/send_test.go | 10 ++-- federationapi/routing/threepid.go | 27 ++++++++-- keyserver/internal/device_list_update.go | 5 +- keyserver/internal/device_list_update_test.go | 14 +++-- keyserver/internal/internal.go | 4 +- keyserver/keyserver.go | 2 +- roomserver/api/input.go | 5 +- roomserver/api/perform.go | 2 + roomserver/api/wrapper.go | 11 ++-- roomserver/internal/alias.go | 7 ++- roomserver/internal/api.go | 8 +-- roomserver/internal/input/input.go | 7 ++- roomserver/internal/input/input_events.go | 6 ++- roomserver/internal/input/input_missing.go | 17 +++--- roomserver/internal/perform/perform_admin.go | 2 +- .../internal/perform/perform_backfill.go | 46 ++++++++-------- .../internal/perform/perform_upgrade.go | 6 +-- roomserver/roomserver_test.go | 2 +- setup/base/base.go | 19 +++++-- setup/config/config_global.go | 29 ++++++++++ setup/mscs/msc2836/msc2836.go | 6 +-- setup/mscs/msc2946/msc2946.go | 2 +- syncapi/syncapi_test.go | 4 +- 48 files changed, 336 insertions(+), 186 deletions(-) diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index eefe8e24b..a0d80903d 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -477,7 +477,7 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index ce14745aa..b3c5aae45 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -77,7 +77,7 @@ func DirectoryRoom( // If we don't know it locally, do a federation query. // But don't send the query to ourselves. if !cfg.Matrix.IsLocalServerName(domain) { - fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) + fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index b1043e994..606744767 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -74,7 +74,7 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( - req.Context(), serverName, + req.Context(), cfg.Matrix.ServerName, serverName, int(request.Limit), request.Since, request.Filter.SearchTerms, false, "", diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 94ba17a02..f3a687017 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + device.UserDomain(), serverName, serverName, nil, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4d9e1f8a5..5ccddcaff 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -298,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -321,7 +321,7 @@ func getProfile( } if !cfg.Matrix.IsLocalServerName(domain) { - profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") + profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 778a02fd4..c343b4a1e 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -132,7 +132,7 @@ func SendRedaction( } } domain := device.UserDomain() - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index bb66cf6fc..48625eadf 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -186,6 +186,7 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, + device.UserDomain(), domain, domain, txnAndSessionID, diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a7acee326..fb93d8783 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -231,6 +231,7 @@ func SendServerNotice( []*gomatrixserverlib.HeaderedEvent{ e.Headered(roomVersion), }, + device.UserDomain(), cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName, txnAndSessionID, diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index d3d1c22e4..62af9efa4 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -106,7 +106,7 @@ knownUsersLoop: continue } // TODO: We should probably cache/store this - fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") + fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 99fb8171d..a5a52ffb9 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -371,6 +371,7 @@ func emit3PIDInviteEvent( []*gomatrixserverlib.HeaderedEvent{ event.Headered(queryRes.RoomVersion), }, + device.UserDomain(), cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go index 27e18c2a3..a91434f62 100644 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ b/cmd/dendrite-demo-pinecone/conn/client.go @@ -101,9 +101,7 @@ func CreateFederationClient( base *base.BaseDendrite, s *pineconeSessions.Sessions, ) *gomatrixserverlib.FederationClient { return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, - base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(createTransport(s)), ) } diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go index 0fafbedc3..0ac705cc1 100644 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ b/cmd/dendrite-demo-pinecone/rooms/rooms.go @@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { for _, k := range p.r.Peers() { list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} } - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers map[gomatrixserverlib.ServerName]struct{}, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 358d3725e..41a9ec123 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient( }, ) return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(tr), ) } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 402b86ed3..0de64755e 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider( } func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.node.DerivedServerName()), + p.node.KnownNodes(), + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/furl/main.go b/cmd/furl/main.go index f59f9c8ce..b208ba868 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -48,10 +48,15 @@ func main() { panic("unexpected key block") } + serverName := gomatrixserverlib.ServerName(*requestFrom) client := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName(*requestFrom), - gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), - privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), + PrivateKey: privateKey, + }, + }, ) u, err := url.Parse(flag.Arg(0)) @@ -79,6 +84,7 @@ func main() { req := gomatrixserverlib.NewFederationRequest( method, + serverName, gomatrixserverlib.ServerName(u.Host), u.RequestURI(), ) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 362333fc9..e34c9e8b6 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,8 +21,8 @@ type FederationInternalAPI interface { QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -60,18 +60,18 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) } // an interface for gmsl.FederationClient - contains functions called by federationapi only. @@ -80,28 +80,28 @@ type FederationClient interface { SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) // Perform operations - LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) - Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) - MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) - SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) - MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) - SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) - Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 3acaa70dc..cc03cdece 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -104,7 +104,7 @@ func TestMain(m *testing.M) { // Create the federation client. s.fedclient = gomatrixserverlib.NewFederationClient( - s.config.Matrix.ServerName, serverKeyID, testPriv, + s.config.Matrix.SigningIdentities(), gomatrixserverlib.WithTransport(transport), ) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index c37bc87c2..68a06a033 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version @@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName } return } -func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { @@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) fedCli := gomatrixserverlib.NewFederationClient( - serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, + cfg.Global.SigningIdentities(), gomatrixserverlib.WithSkipVerify(true), ) @@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { t.Errorf("failed to create invite v2 request: %s", err) continue } - _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq) if err == nil { t.Errorf("expected an error, got none") continue diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 2636b7fa0..db6348ec1 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -11,13 +11,13 @@ import ( // client. func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res gomatrixserverlib.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) + return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespEventAuth{}, err @@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth( } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetUserDevices(ctx, s, userID) + return a.federation.GetUserDevices(ctx, origin, s, userID) }) if err != nil { return gomatrixserverlib.RespUserDevices{}, err @@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices( } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.ClaimKeys(ctx, s, oneTimeKeys) + return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) if err != nil { return gomatrixserverlib.RespClaimKeys{}, err @@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys( } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { - return a.federation.QueryKeys(ctx, s, keys) + return a.federation.QueryKeys(ctx, origin, s, keys) }) if err != nil { return gomatrixserverlib.RespQueryKeys{}, err @@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys( } func (a *FederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespState, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) if err != nil { return gomatrixserverlib.RespState{}, err @@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState( } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (res gomatrixserverlib.RespStateIDs, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespStateIDs{}, err @@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs( } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) if err != nil { return gomatrixserverlib.RespMissingEvents{}, err @@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents( } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEvent(ctx, s, eventID) + return a.federation.GetEvent(ctx, origin, s, eventID) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) if err != nil { return res, err @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) + return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 1b61ec711..d60da8f76 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( ) (err error) { dir, err := r.federation.LookupRoomAlias( ctx, + r.cfg.Matrix.ServerName, request.ServerName, request.RoomAlias, ) @@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) + if err != nil { + return err + } + // Try to perform a make_join using the information supplied in the // request. respMakeJoin, err := r.federation.MakeJoin( ctx, + origin, serverName, roomID, userID, @@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( context.Background(), + origin, serverName, event, ) @@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( respMakeJoin.RoomVersion, r.keyRing, event, - federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), ) if err != nil { return fmt.Errorf("respSendJoin.Check: %w", err) @@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err = roomserverAPI.SendEventWithState( context.Background(), r.rsAPI, + origin, roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), @@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // request. respPeek, err := r.federation.Peek( ctx, + r.cfg.Matrix.ServerName, serverName, roomID, peekID, @@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) + authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // logrus.Warnf("got respPeek %#v", respPeek) // Send the newly returned state to the roomserver to update our local view. if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, + ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), @@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + // Deduplicate the server names we were provided. util.SortAndUnique(request.ServerNames) @@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, + origin, serverName, request.RoomID, request.UserID, @@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, + origin, serverName, event, ) @@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + if err != nil { + return err + } + if request.Event.StateKey() == nil { return errors.New("invite must be a state event") } @@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -708,7 +730,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -738,7 +760,7 @@ func federatedAuthProvider( // Try to retrieve the event from the server that sent us the send // join response. - tx, txerr := federation.GetEvent(ctx, server, eventID) + tx, txerr := federation.GetEvent(ctx, origin, server, eventID) if txerr != nil { return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 812d3c6da..43f88b828 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -152,11 +152,12 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( type getUserDevices struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName UserID string } func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, @@ -169,11 +170,12 @@ func (h *httpFederationInternalAPI) GetUserDevices( type claimKeys struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string } func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, @@ -185,12 +187,13 @@ func (h *httpFederationInternalAPI) ClaimKeys( } type queryKeys struct { - S gomatrixserverlib.ServerName - Keys map[string][]string + S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName + Keys map[string][]string } func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, @@ -203,18 +206,20 @@ func (h *httpFederationInternalAPI) QueryKeys( type backfill struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Limit int EventIDs []string } func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, ctx, &backfill{ S: s, + Origin: origin, RoomID: roomID, Limit: limit, EventIDs: eventIDs, @@ -224,18 +229,20 @@ func (h *httpFederationInternalAPI) Backfill( type lookupState struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, ctx, &lookupState{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, RoomVersion: roomVersion, @@ -245,17 +252,19 @@ func (h *httpFederationInternalAPI) LookupState( type lookupStateIDs struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string } func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, ctx, &lookupStateIDs{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, }, @@ -264,19 +273,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs( type lookupMissingEvents struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, ctx, &lookupMissingEvents{ S: s, + Origin: origin, RoomID: roomID, Missing: missing, RoomVersion: roomVersion, @@ -286,16 +297,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( type getEvent struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName EventID string } func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, ctx, &getEvent{ S: s, + Origin: origin, EventID: eventID, }, ) @@ -303,19 +316,21 @@ func (h *httpFederationInternalAPI) GetEvent( type getEventAuth struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string } func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, ctx, &getEventAuth{ S: s, + Origin: origin, RoomVersion: roomVersion, RoomID: roomID, EventID: eventID, @@ -351,18 +366,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys( type eventRelationships struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, ctx, &eventRelationships{ S: s, + Origin: origin, Req: r, RoomVer: roomVersion, }, @@ -371,17 +388,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( type spacesReq struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName SuggestedOnly bool RoomID string } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, ctx, &spacesReq{ S: dst, + Origin: origin, SuggestedOnly: suggestedOnly, RoomID: roomID, }, diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7aa0e4801..7b3edb2a7 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -59,7 +59,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetUserDevices", func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) }, ), @@ -70,7 +70,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIClaimKeys", func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) }, ), @@ -81,7 +81,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIQueryKeys", func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) }, ), @@ -92,7 +92,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIBackfill", func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) }, ), @@ -103,7 +103,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupState", func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -114,7 +114,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupStateIDs", func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -125,7 +125,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupMissingEvents", func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -136,7 +136,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEvent", func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) }, ), @@ -147,7 +147,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEventAuth", func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -174,7 +174,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2836EventRelationships", func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) }, ), @@ -185,7 +185,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2946SpacesSummary", func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) }, ), diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 7b9ca66f6..006c0efa6 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -82,7 +82,8 @@ func Backfill( BackwardsExtremities: map[string][]string{ "": eIDs, }, - ServerName: request.Origin(), + ServerName: request.Origin(), + VirtualHost: request.Destination(), } if req.Limit, err = strconv.Atoi(limit); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 316c61a14..e6dc52601 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -83,7 +83,7 @@ func RoomAliasToID( } } } else { - resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) + resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias) if err != nil { switch x := err.(type) { case gomatrix.HTTPError: diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b3bbaa394..19134c9ee 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -164,7 +164,7 @@ func Send( util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.processTransaction(httpReq.Context(), request.Destination()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -197,16 +197,16 @@ type txnReq struct { // A subset of FederationClient functionality that txn requires. Useful for testing. type txnFederationClient interface { - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { +func (t *txnReq) processTransaction(ctx context.Context, virtualHost gomatrixserverlib.ServerName) (*gomatrixserverlib.RespSend, *util.JSONResponse) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res []*gomatrixserverlib.HeaderedEvent{ event.Headered(roomVersion), }, + t.Destination, t.Origin, api.DoNotSendToOtherServers, nil, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 1c796f542..ebcbba3e7 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -147,7 +147,7 @@ type txnFedClient struct { getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) } -func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { fmt.Println("testFederationClient.LookupState", eventID) @@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv res = r return } -func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { +func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { fmt.Println("testFederationClient.LookupStateIDs", eventID) r, ok := c.stateIDs[eventID] if !ok { @@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S res = r return } -func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { +func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { fmt.Println("testFederationClient.GetEvent", eventID) r, ok := c.getEvent[eventID] if !ok { @@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN res = r return } -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, +func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { return c.getMissingEvents(missing) } @@ -199,7 +199,7 @@ func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFeder } func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) + res, err := txn.processTransaction(context.Background(), "localhost") if err != nil { t.Errorf("txn.processTransaction returned an error: %v", err) return diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index ccde9168e..d07faef39 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents( + req.Context(), + rsAPI, + api.KindNew, + evs, + cfg.Matrix.ServerName, // TODO: which virtual host? + "TODO", + cfg.Matrix.ServerName, + nil, + false, + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite( } } + _, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + } + } + // Check that the state key is correct. _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) if err != nil { @@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") return jsonerror.InternalServerError() } - signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) + signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ inviteEvent.Headered(verRes.RoomVersion), }, + request.Destination(), request.Origin(), cfg.Matrix.ServerName, nil, @@ -341,7 +360,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation federationAPI.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, cfg *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) @@ -357,7 +376,7 @@ func sendToRemoteServer( } for _, server := range remoteServers { - err = federation.ExchangeThirdPartyInvite(ctx, server, builder) + err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder) if err == nil { return } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 3f7c0d8b4..45f3c54e5 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -96,6 +96,7 @@ type DeviceListUpdater struct { producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -139,6 +140,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -148,6 +150,7 @@ func NewDeviceListUpdater( api: api, producer: producer, fedClient: fedClient, + thisServer: thisServer, workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, @@ -436,7 +439,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go "user_id": userID, }) - res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return time.Minute * 10, err diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 28a13a0a0..86fb283ec 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) fedClient := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, ) fedClient.Client = *gomatrixserverlib.NewClient( gomatrixserverlib.WithTransport(&roundTripper{tripper}), @@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 37c55c8f8..9a08a0bb7 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -146,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -559,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 0a4b8fde6..a86c2da4e 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -58,7 +58,7 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 45a9ef497..88d523270 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -94,8 +94,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 2536a4bb5..e70e5ea9c 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -148,6 +148,8 @@ type PerformBackfillRequest struct { Limit int `json:"limit"` // The server interested in the events. ServerName gomatrixserverlib.ServerName `json:"server_name"` + // Which virtual host are we doing this for? + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 8b031982c..252be557f 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, + virtualHost, origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, async bool, ) error { @@ -40,14 +40,15 @@ func SendEvents( TransactionID: txnID, } } - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendEventWithState writes an event with the specified kind to the roomserver // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, + ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { @@ -85,17 +86,19 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ InputRoomEvents: ires, Asynchronous: async, + VirtualHost: virtualHost, } var response InputRoomEventsResponse if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 175bb9310..8e06e3d5a 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { + _, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -216,7 +221,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) + err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a11586a5..d0ce57537 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -163,10 +163,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, } r.Backfiller = &perform.Backfiller{ - ServerName: r.ServerName, - DB: r.DB, - FSAPI: r.fsAPI, - KeyRing: r.KeyRing, + IsLocalServerName: r.Cfg.Matrix.IsLocalServerName, + DB: r.DB, + FSAPI: r.fsAPI, + KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather // than trying random servers diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index f5099ca11..e965691c9 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -278,7 +278,11 @@ func (w *worker) _next() { // a string, because we might want to return that to the caller if // it was a synchronous request. var errString string - if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { + if err = w.r.processRoomEvent( + w.r.ProcessContext.Context(), + gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + &inputRoomEvent, + ); err != nil { switch err.(type) { case types.RejectedError: // Don't send events that were rejected to Sentry @@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents( if replyTo != "" { msg.Header.Set("sync", replyTo) } + msg.Header.Set("virtual_host", string(request.VirtualHost)) msg.Data, err = json.Marshal(e) if err != nil { return nil, fmt.Errorf("json.Marshal: %w", err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 682aa2b12..1fc28ff77 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -69,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + virtualHost gomatrixserverlib.ServerName, input *api.InputRoomEvent, ) error { select { @@ -200,7 +201,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -555,6 +556,7 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, + virtualHost gomatrixserverlib.ServerName, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, @@ -605,7 +607,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f788c5b3a..03ac2b38d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { type missingStateReq struct { log *logrus.Entry + virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName db storage.Database roomInfo *types.RoomInfo @@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve var missingResp *gomatrixserverlib.RespMissingEvents for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState( span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") defer span.Finish() - state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } @@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) for _, serverName := range t.servers { reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) - stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) + stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID) reqcancel() if err == nil { break @@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) + txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") if errors.Is(err, context.DeadlineExceeded) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 0406fc8f4..131ff3c6e 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -273,7 +273,7 @@ func (r *Admin) PerformAdminDownloadState( for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.RespState - state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 57e121ea2..97d8e2f32 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -37,10 +37,10 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - ServerName gomatrixserverlib.ServerName - DB storage.Database - FSAPI federationAPI.RoomserverFederationAPI - KeyRing gomatrixserverlib.JSONVerifier + IsLocalServerName func(gomatrixserverlib.ServerName) bool + DB storage.Database + FSAPI federationAPI.RoomserverFederationAPI + KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling PreferServers []gomatrixserverlib.ServerName @@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill( // if we are requesting the backfill then we need to do a federation hit // TODO: we could be more sensible and fetch as many events we already have then request the rest // which is what the syncapi does already. - if request.ServerName == r.ServerName { + if !r.IsLocalServerName(request.ServerName) { return r.backfillViaFederation(ctx, request, response) } // someone else is requesting the backfill, try to service their request. @@ -112,14 +112,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // (so we don't need to hit /state_ids which the test has no listener for) // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, + ctx, req.VirtualHost, requester, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) if err != nil { return err @@ -144,7 +144,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost) // try again entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) if err != nil { @@ -173,7 +173,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { servers := backfillRequester.servers @@ -198,7 +198,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue // already found } logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FSAPI.GetEvent(ctx, srv, id) + res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id) if err != nil { logger.WithError(err).Warn("failed to get event from server") continue @@ -241,11 +241,12 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { - db storage.Database - fsAPI federationAPI.RoomserverFederationAPI - thisServer gomatrixserverlib.ServerName - preferServer map[gomatrixserverlib.ServerName]bool - bwExtrems map[string][]string + db storage.Database + fsAPI federationAPI.RoomserverFederationAPI + virtualHost gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool + preferServer map[gomatrixserverlib.ServerName]bool + bwExtrems map[string][]string // per-request state servers []gomatrixserverlib.ServerName @@ -255,7 +256,9 @@ type backfillRequester struct { } func newBackfillRequester( - db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, + db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + virtualHost gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, ) *backfillRequester { preferServer := make(map[gomatrixserverlib.ServerName]bool) @@ -265,7 +268,8 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, - thisServer: thisServer, + virtualHost: virtualHost, + isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, @@ -450,7 +454,7 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -477,7 +481,7 @@ FindSuccessor: } var servers []gomatrixserverlib.ServerName for server := range serverSet { - if server == b.thisServer { + if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front @@ -496,10 +500,10 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { - tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) return tx, err } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 38abe323c..f627e86fc 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", &api.PerformError{ Code: api.PerformErrorNotAllowed, @@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), } @@ -686,7 +686,7 @@ func (r *Upgrader) sendHeaderedEvent( Origin: serverName, SendAsServer: sendAsServer, }) - if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 4e98af853..24b5515e5 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) { // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } diff --git a/setup/base/base.go b/setup/base/base.go index 2e3a3a195..2d6548483 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -364,10 +364,22 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { + identities := make([]*gomatrixserverlib.SigningIdentity, 0, 1+len(b.Cfg.Global.SecondaryServerNames)) + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: b.Cfg.Global.ServerName, + KeyID: b.Cfg.Global.KeyID, + PrivateKey: b.Cfg.Global.PrivateKey, + }) + for _, serverName := range b.Cfg.Global.SecondaryServerNames { + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + KeyID: b.Cfg.Global.KeyID, // TODO: Per-virtual host key + PrivateKey: b.Cfg.Global.PrivateKey, // TODO: Per-virtual host key + }) + } if b.Cfg.Global.DisableFederation { return gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, - gomatrixserverlib.WithTransport(noOpHTTPTransport), + identities, gomatrixserverlib.WithTransport(noOpHTTPTransport), ) } opts := []gomatrixserverlib.ClientOption{ @@ -379,8 +391,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) } client := gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, - b.Cfg.Global.PrivateKey, opts..., + identities, opts..., ) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 825772827..fe34b5cac 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "math/rand" "strconv" "strings" @@ -135,6 +136,34 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool return false } +func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { + u, s, err := gomatrixserverlib.SplitID(sigil, id) + if err != nil { + return u, s, err + } + if !c.IsLocalServerName(s) { + return u, s, fmt.Errorf("server name not locally configured") + } + return u, s, nil +} + +func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { + identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.SecondaryServerNames)+1) + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: c.ServerName, + KeyID: c.KeyID, + PrivateKey: c.PrivateKey, + }) + for _, serverName := range c.SecondaryServerNames { + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + KeyID: c.KeyID, + PrivateKey: c.PrivateKey, + }) + } + return identities +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 98502f5cb..bc369c166 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen serversToQuery := rc.getServersForEventID(parentID) var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, Direction: "down", Limit: 100, @@ -484,7 +484,7 @@ func walkThread( // MSC2836EventRelationships performs an /event_relationships request to a remote server func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, Direction: rc.req.Direction, @@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index d7c022544..56c063598 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -433,7 +433,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a4985dbf4..483274481 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -433,7 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) eventsToSend := append(room.Events(), beforeJoinEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, @@ -472,7 +472,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false,