Federation fixes for virtual hosting

This commit is contained in:
Neil Alexander 2022-11-15 15:05:23 +00:00
parent f4ee397734
commit 6650712a1c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
73 changed files with 736 additions and 420 deletions

View file

@ -110,7 +110,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil { if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
localpart, serverName = l, s localpart, serverName = l, s
} }
request := struct { request := struct {

View file

@ -477,7 +477,7 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -77,7 +77,7 @@ func DirectoryRoom(
// If we don't know it locally, do a federation query. // If we don't know it locally, do a federation query.
// But don't send the query to ourselves. // But don't send the query to ourselves.
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias)
if fedErr != nil { if fedErr != nil {
// TODO: Return 502 if the remote server errored. // TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out. // TODO: Return 504 if the remote server timed out.

View file

@ -74,7 +74,7 @@ func GetPostPublicRooms(
serverName := gomatrixserverlib.ServerName(request.Server) serverName := gomatrixserverlib.ServerName(request.Server)
if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) {
res, err := federation.GetPublicRoomsFiltered( res, err := federation.GetPublicRoomsFiltered(
req.Context(), serverName, req.Context(), cfg.Matrix.ServerName, serverName,
int(request.Limit), request.Since, int(request.Limit), request.Since,
request.Filter.SearchTerms, false, request.Filter.SearchTerms, false,
"", "",

View file

@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
ctx, rsAPI, ctx, rsAPI,
roomserverAPI.KindNew, roomserverAPI.KindNew,
[]*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
device.UserDomain(),
serverName, serverName,
serverName, serverName,
nil, nil,
@ -322,7 +323,12 @@ func buildMembershipEvent(
return nil, err return nil, err
} }
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
} }
// loadProfile lookups the profile of a given user from the database and returns // loadProfile lookups the profile of a given user from the database and returns

View file

@ -284,7 +284,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -298,7 +298,7 @@ func updateProfile(
return jsonerror.InternalServerError(), e return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil {
util.GetLogger(ctx).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
@ -321,7 +321,7 @@ func getProfile(
} }
if !cfg.Matrix.IsLocalServerName(domain) { if !cfg.Matrix.IsLocalServerName(domain) {
profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {
@ -349,6 +349,7 @@ func getProfile(
func buildMembershipEvents( func buildMembershipEvents(
ctx context.Context, ctx context.Context,
device *userapi.Device,
roomIDs []string, roomIDs []string,
newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI,
evTime time.Time, rsAPI api.ClientRoomserverAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI,
@ -380,7 +381,12 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return nil, err
}
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -123,8 +123,13 @@ func SendRedaction(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return jsonerror.InternalServerError()
}
var queryRes roomserverAPI.QueryLatestEventsAndStateResponse var queryRes roomserverAPI.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -132,7 +137,7 @@ func SendRedaction(
} }
} }
domain := device.UserDomain() domain := device.UserDomain()
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -211,9 +211,10 @@ var (
// previous parameters with the ones supplied. This mean you cannot "build up" request params. // previous parameters with the ones supplied. This mean you cannot "build up" request params.
type registerRequest struct { type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
Admin bool `json:"admin"` ServerName gomatrixserverlib.ServerName `json:"-"`
Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -570,11 +571,14 @@ func Register(
JSON: response, JSON: response,
} }
} }
} }
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
r.ServerName = cfg.Matrix.ServerName
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
@ -589,7 +593,7 @@ func Register(
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
nreq := &userapi.QueryNumericLocalpartRequest{ nreq := &userapi.QueryNumericLocalpartRequest{
ServerName: cfg.Matrix.ServerName, // TODO: might not be right ServerName: r.ServerName,
} }
nres := &userapi.QueryNumericLocalpartResponse{} nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
@ -609,7 +613,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -622,7 +626,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -1023,13 +1027,27 @@ func RegisterAvailable(
// Squash username to all lowercase letters // Squash username to all lowercase letters
username = strings.ToLower(username) username = strings.ToLower(username)
domain := cfg.Matrix.ServerName
if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil {
username, domain = u, l
}
for _, v := range cfg.Matrix.VirtualHosts {
if v.ServerName == domain && !v.AllowRegistration {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)),
),
}
}
}
if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { if err := validateUsername(username, domain); err != nil {
return *err return *err
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) userID := userutil.MakeUserID(username, domain)
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if appservice.OwnsNamespaceCoveringUserId(userID) { if appservice.OwnsNamespaceCoveringUserId(userID) {
return util.JSONResponse{ return util.JSONResponse{
@ -1041,7 +1059,8 @@ func RegisterAvailable(
res := &userapi.QueryAccountAvailabilityResponse{} res := &userapi.QueryAccountAvailabilityResponse{}
err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: username, Localpart: username,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -186,6 +186,7 @@ func SendEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(verRes.RoomVersion), e.Headered(verRes.RoomVersion),
}, },
device.UserDomain(),
domain, domain,
domain, domain,
txnAndSessionID, txnAndSessionID,
@ -275,8 +276,14 @@ func generateSendEvent(
return nil, &resErr return nil, &resErr
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
resErr := jsonerror.InternalServerError()
return nil, &resErr
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -231,6 +231,7 @@ func SendServerNotice(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
e.Headered(roomVersion), e.Headered(roomVersion),
}, },
device.UserDomain(),
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName,
txnAndSessionID, txnAndSessionID,

View file

@ -106,7 +106,7 @@ knownUsersLoop:
continue continue
} }
// TODO: We should probably cache/store this // TODO: We should probably cache/store this
fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "")
if fedErr != nil { if fedErr != nil {
if x, ok := fedErr.(gomatrix.HTTPError); ok { if x, ok := fedErr.(gomatrix.HTTPError); ok {
if x.Code == http.StatusNotFound { if x.Code == http.StatusNotFound {

View file

@ -359,8 +359,13 @@ func emit3PIDInviteEvent(
return err return err
} }
identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain())
if err != nil {
return err
}
queryRes := api.QueryLatestEventsAndStateResponse{} queryRes := api.QueryLatestEventsAndStateResponse{}
event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes)
if err != nil { if err != nil {
return err return err
} }
@ -371,6 +376,7 @@ func emit3PIDInviteEvent(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(queryRes.RoomVersion), event.Headered(queryRes.RoomVersion),
}, },
device.UserDomain(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,

View file

@ -101,9 +101,7 @@ func CreateFederationClient(
base *base.BaseDendrite, s *pineconeSessions.Sessions, base *base.BaseDendrite, s *pineconeSessions.Sessions,
) *gomatrixserverlib.FederationClient { ) *gomatrixserverlib.FederationClient {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.KeyID,
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(createTransport(s)), gomatrixserverlib.WithTransport(createTransport(s)),
) )
} }

View file

@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
for _, k := range p.r.Peers() { for _, k := range p.r.Peers() {
list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{}
} }
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.r.PublicKey().String()), list,
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers map[gomatrixserverlib.ServerName]struct{}, homeservers map[gomatrixserverlib.ServerName]struct{},
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient(
}, },
) )
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.SigningIdentities(),
base.Cfg.Global.PrivateKey,
gomatrixserverlib.WithTransport(tr), gomatrixserverlib.WithTransport(tr),
) )
} }

View file

@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider(
} }
func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom {
return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) return bulkFetchPublicRoomsFromServers(
context.Background(), p.fedClient,
gomatrixserverlib.ServerName(p.node.DerivedServerName()),
p.node.KnownNodes(),
)
} }
// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers.
// Returns a list of public rooms. // Returns a list of public rooms.
func bulkFetchPublicRoomsFromServers( func bulkFetchPublicRoomsFromServers(
ctx context.Context, fedClient *gomatrixserverlib.FederationClient, ctx context.Context, fedClient *gomatrixserverlib.FederationClient,
origin gomatrixserverlib.ServerName,
homeservers []gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName,
) (publicRooms []gomatrixserverlib.PublicRoom) { ) (publicRooms []gomatrixserverlib.PublicRoom) {
limit := 200 limit := 200
@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers(
go func(homeserverDomain gomatrixserverlib.ServerName) { go func(homeserverDomain gomatrixserverlib.ServerName) {
defer wg.Done() defer wg.Done()
util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms")
fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchPublicRoomsFromServers: failed to query hs", "bulkFetchPublicRoomsFromServers: failed to query hs",

View file

@ -48,10 +48,15 @@ func main() {
panic("unexpected key block") panic("unexpected key block")
} }
serverName := gomatrixserverlib.ServerName(*requestFrom)
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName(*requestFrom), []*gomatrixserverlib.SigningIdentity{
gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), {
privateKey, ServerName: serverName,
KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]),
PrivateKey: privateKey,
},
},
) )
u, err := url.Parse(flag.Arg(0)) u, err := url.Parse(flag.Arg(0))
@ -79,6 +84,7 @@ func main() {
req := gomatrixserverlib.NewFederationRequest( req := gomatrixserverlib.NewFederationRequest(
method, method,
serverName,
gomatrixserverlib.ServerName(u.Host), gomatrixserverlib.ServerName(u.Host),
u.RequestURI(), u.RequestURI(),
) )

View file

@ -21,8 +21,8 @@ type FederationInternalAPI interface {
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -60,18 +60,18 @@ type RoomserverFederationAPI interface {
// containing only the server names (without information for membership events). // containing only the server names (without information for membership events).
// The response will include this server if they are joined to the room. // The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type KeyserverFederationAPI interface { type KeyserverFederationAPI interface {
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
} }
// an interface for gmsl.FederationClient - contains functions called by federationapi only. // an interface for gmsl.FederationClient - contains functions called by federationapi only.
@ -80,28 +80,28 @@ type FederationClient interface {
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations // Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.

View file

@ -120,17 +120,7 @@ func NewInternalAPI(
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{} signingInfo := base.Cfg.Global.SigningIdentities()
for _, serverName := range append(
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
base.Cfg.Global.SecondaryServerNames...,
) {
signingInfo[serverName] = &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: serverName,
}
}
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,

View file

@ -104,7 +104,7 @@ func TestMain(m *testing.M) {
// Create the federation client. // Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient( s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv, s.config.Matrix.SigningIdentities(),
gomatrixserverlib.WithTransport(transport), gomatrixserverlib.WithTransport(transport),
) )

View file

@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv
return keys, nil return keys, nil
} }
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
if r.ID == roomID { if r.ID == roomID {
res.RoomVersion = r.Version res.RoomVersion = r.Version
@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName
} }
return return
} }
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
f.fedClientMutex.Lock() f.fedClientMutex.Lock()
defer f.fedClientMutex.Unlock() defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins { for _, r := range f.allowJoins {
@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
fedCli := gomatrixserverlib.NewFederationClient( fedCli := gomatrixserverlib.NewFederationClient(
serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, cfg.Global.SigningIdentities(),
gomatrixserverlib.WithSkipVerify(true), gomatrixserverlib.WithSkipVerify(true),
) )
@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
t.Errorf("failed to create invite v2 request: %s", err) t.Errorf("failed to create invite v2 request: %s", err)
continue continue
} }
_, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq)
if err == nil { if err == nil {
t.Errorf("expected an error, got none") t.Errorf("expected an error, got none")
continue continue

View file

@ -11,13 +11,13 @@ import (
// client. // client.
func (a *FederationInternalAPI) GetEventAuth( func (a *FederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (res gomatrixserverlib.RespEventAuth, err error) { ) (res gomatrixserverlib.RespEventAuth, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespEventAuth{}, err return gomatrixserverlib.RespEventAuth{}, err
@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth(
} }
func (a *FederationInternalAPI) GetUserDevices( func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID) return a.federation.GetUserDevices(ctx, origin, s, userID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespUserDevices{}, err return gomatrixserverlib.RespUserDevices{}, err
@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices(
} }
func (a *FederationInternalAPI) ClaimKeys( func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys) return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespClaimKeys{}, err return gomatrixserverlib.RespClaimKeys{}, err
@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys(
} }
func (a *FederationInternalAPI) QueryKeys( func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, s, keys) return a.federation.QueryKeys(ctx, origin, s, keys)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespQueryKeys{}, err return gomatrixserverlib.RespQueryKeys{}, err
@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys(
} }
func (a *FederationInternalAPI) Backfill( func (a *FederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill(
} }
func (a *FederationInternalAPI) LookupState( func (a *FederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) { ) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespState{}, err return gomatrixserverlib.RespState{}, err
@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState(
} }
func (a *FederationInternalAPI) LookupStateIDs( func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) { ) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID) return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespStateIDs{}, err return gomatrixserverlib.RespStateIDs{}, err
@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs(
} }
func (a *FederationInternalAPI) LookupMissingEvents( func (a *FederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.RespMissingEvents{}, err return gomatrixserverlib.RespMissingEvents{}, err
@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents(
} }
func (a *FederationInternalAPI) GetEvent( func (a *FederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) { ) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30) ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID) return a.federation.GetEvent(ctx, origin, s, eventID)
}) })
if err != nil { if err != nil {
return gomatrixserverlib.Transaction{}, err return gomatrixserverlib.Transaction{}, err
@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys(
} }
func (a *FederationInternalAPI) MSC2836EventRelationships( func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion)
}) })
if err != nil { if err != nil {
return res, err return res, err
@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships(
} }
func (a *FederationInternalAPI) MSC2946Spaces( func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly)
}) })
if err != nil { if err != nil {
return res, err return res, err

View file

@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
) (err error) { ) (err error) {
dir, err := r.federation.LookupRoomAlias( dir, err := r.federation.LookupRoomAlias(
ctx, ctx,
r.cfg.Matrix.ServerName,
request.ServerName, request.ServerName,
request.RoomAlias, request.RoomAlias,
) )
@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion, supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{}, unsigned map[string]interface{},
) error { ) error {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
}
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
respMakeJoin, err := r.federation.MakeJoin( respMakeJoin, err := r.federation.MakeJoin(
ctx, ctx,
origin,
serverName, serverName,
roomID, roomID,
userID, userID,
@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Build the join event. // Build the join event.
event, err := respMakeJoin.JoinEvent.Build( event, err := respMakeJoin.JoinEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
context.Background(), context.Background(),
origin,
serverName, serverName,
event, event,
) )
@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
r.keyRing, r.keyRing,
event, event,
federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName),
) )
if err != nil { if err != nil {
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,
origin,
roomserverAPI.KindNew, roomserverAPI.KindNew,
respState, respState,
event.Headered(respMakeJoin.RoomVersion), event.Headered(respMakeJoin.RoomVersion),
@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// request. // request.
respPeek, err := r.federation.Peek( respPeek, err := r.federation.Peek(
ctx, ctx,
r.cfg.Matrix.ServerName,
serverName, serverName,
roomID, roomID,
peekID, peekID,
@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName))
if err != nil { if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err) return fmt.Errorf("error checking state returned from peeking: %w", err)
} }
@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// logrus.Warnf("got respPeek %#v", respPeek) // logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view. // Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI, ctx, r.rsAPI, r.cfg.Matrix.ServerName,
roomserverAPI.KindNew, roomserverAPI.KindNew,
&respState, &respState,
respPeek.LatestEvent.Headered(respPeek.RoomVersion), respPeek.LatestEvent.Headered(respPeek.RoomVersion),
@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
// Deduplicate the server names we were provided. // Deduplicate the server names we were provided.
util.SortAndUnique(request.ServerNames) util.SortAndUnique(request.ServerNames)
@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave(
// request. // request.
respMakeLeave, err := r.federation.MakeLeave( respMakeLeave, err := r.federation.MakeLeave(
ctx, ctx,
origin,
serverName, serverName,
request.RoomID, request.RoomID,
request.UserID, request.UserID,
@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Build the leave event. // Build the leave event.
event, err := respMakeLeave.LeaveEvent.Build( event, err := respMakeLeave.LeaveEvent.Build(
time.Now(), time.Now(),
r.cfg.Matrix.ServerName, origin,
r.cfg.Matrix.KeyID, r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey, r.cfg.Matrix.PrivateKey,
respMakeLeave.RoomVersion, respMakeLeave.RoomVersion,
@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave(
// Try to perform a send_leave using the newly built event. // Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave( err = r.federation.SendLeave(
ctx, ctx,
origin,
serverName, serverName,
event, event,
) )
@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest, request *api.PerformInviteRequest,
response *api.PerformInviteResponse, response *api.PerformInviteResponse,
) (err error) { ) (err error) {
_, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender())
if err != nil {
return err
}
if request.Event.StateKey() == nil { if request.Event.StateKey() == nil {
return errors.New("invite must be a state event") return errors.New("invite must be a state event")
} }
@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
} }
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq)
if err != nil { if err != nil {
return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
} }
@ -708,7 +730,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation api.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join. // the auth events supplied in the send_join.
@ -738,7 +760,7 @@ func federatedAuthProvider(
// Try to retrieve the event from the server that sent us the send // Try to retrieve the event from the server that sent us the send
// join response. // join response.
tx, txerr := federation.GetEvent(ctx, server, eventID) tx, txerr := federation.GetEvent(ctx, origin, server, eventID)
if txerr != nil { if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
} }

View file

@ -152,16 +152,18 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
type getUserDevices struct { type getUserDevices struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
UserID string UserID string
} }
func (h *httpFederationInternalAPI) GetUserDevices( func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) { ) (gomatrixserverlib.RespUserDevices, error) {
return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
"GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
ctx, &getUserDevices{ ctx, &getUserDevices{
S: s, S: s,
Origin: origin,
UserID: userID, UserID: userID,
}, },
) )
@ -169,52 +171,58 @@ func (h *httpFederationInternalAPI) GetUserDevices(
type claimKeys struct { type claimKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string OneTimeKeys map[string]map[string]string
} }
func (h *httpFederationInternalAPI) ClaimKeys( func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) { ) (gomatrixserverlib.RespClaimKeys, error) {
return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
"ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
ctx, &claimKeys{ ctx, &claimKeys{
S: s, S: s,
Origin: origin,
OneTimeKeys: oneTimeKeys, OneTimeKeys: oneTimeKeys,
}, },
) )
} }
type queryKeys struct { type queryKeys struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Keys map[string][]string Origin gomatrixserverlib.ServerName
Keys map[string][]string
} }
func (h *httpFederationInternalAPI) QueryKeys( func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) { ) (gomatrixserverlib.RespQueryKeys, error) {
return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
"QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
ctx, &queryKeys{ ctx, &queryKeys{
S: s, S: s,
Keys: keys, Origin: origin,
Keys: keys,
}, },
) )
} }
type backfill struct { type backfill struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Limit int Limit int
EventIDs []string EventIDs []string
} }
func (h *httpFederationInternalAPI) Backfill( func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
"Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
ctx, &backfill{ ctx, &backfill{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Limit: limit, Limit: limit,
EventIDs: eventIDs, EventIDs: eventIDs,
@ -224,18 +232,20 @@ func (h *httpFederationInternalAPI) Backfill(
type lookupState struct { type lookupState struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupState( func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) { ) (gomatrixserverlib.RespState, error) {
return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
"LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
ctx, &lookupState{ ctx, &lookupState{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -245,17 +255,19 @@ func (h *httpFederationInternalAPI) LookupState(
type lookupStateIDs struct { type lookupStateIDs struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) LookupStateIDs( func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) { ) (gomatrixserverlib.RespStateIDs, error) {
return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
"LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
ctx, &lookupStateIDs{ ctx, &lookupStateIDs{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
}, },
@ -264,19 +276,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs(
type lookupMissingEvents struct { type lookupMissingEvents struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomID string RoomID string
Missing gomatrixserverlib.MissingEvents Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) LookupMissingEvents( func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) { ) (res gomatrixserverlib.RespMissingEvents, err error) {
return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
"LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
ctx, &lookupMissingEvents{ ctx, &lookupMissingEvents{
S: s, S: s,
Origin: origin,
RoomID: roomID, RoomID: roomID,
Missing: missing, Missing: missing,
RoomVersion: roomVersion, RoomVersion: roomVersion,
@ -286,16 +300,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents(
type getEvent struct { type getEvent struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEvent( func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) { ) (gomatrixserverlib.Transaction, error) {
return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
"GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
ctx, &getEvent{ ctx, &getEvent{
S: s, S: s,
Origin: origin,
EventID: eventID, EventID: eventID,
}, },
) )
@ -303,19 +319,21 @@ func (h *httpFederationInternalAPI) GetEvent(
type getEventAuth struct { type getEventAuth struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
RoomVersion gomatrixserverlib.RoomVersion RoomVersion gomatrixserverlib.RoomVersion
RoomID string RoomID string
EventID string EventID string
} }
func (h *httpFederationInternalAPI) GetEventAuth( func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName, ctx context.Context, origin, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) { ) (gomatrixserverlib.RespEventAuth, error) {
return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
"GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
ctx, &getEventAuth{ ctx, &getEventAuth{
S: s, S: s,
Origin: origin,
RoomVersion: roomVersion, RoomVersion: roomVersion,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
@ -351,18 +369,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys(
type eventRelationships struct { type eventRelationships struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion RoomVer gomatrixserverlib.RoomVersion
} }
func (h *httpFederationInternalAPI) MSC2836EventRelationships( func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
ctx, &eventRelationships{ ctx, &eventRelationships{
S: s, S: s,
Origin: origin,
Req: r, Req: r,
RoomVer: roomVersion, RoomVer: roomVersion,
}, },
@ -371,17 +391,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships(
type spacesReq struct { type spacesReq struct {
S gomatrixserverlib.ServerName S gomatrixserverlib.ServerName
Origin gomatrixserverlib.ServerName
SuggestedOnly bool SuggestedOnly bool
RoomID string RoomID string
} }
func (h *httpFederationInternalAPI) MSC2946Spaces( func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
"MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
ctx, &spacesReq{ ctx, &spacesReq{
S: dst, S: dst,
Origin: origin,
SuggestedOnly: suggestedOnly, SuggestedOnly: suggestedOnly,
RoomID: roomID, RoomID: roomID,
}, },

View file

@ -59,7 +59,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetUserDevices", "FederationAPIGetUserDevices",
func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -70,7 +70,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIClaimKeys", "FederationAPIClaimKeys",
func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -81,7 +81,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIQueryKeys", "FederationAPIQueryKeys",
func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -92,7 +92,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIBackfill", "FederationAPIBackfill",
func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -103,7 +103,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupState", "FederationAPILookupState",
func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -114,7 +114,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupStateIDs", "FederationAPILookupStateIDs",
func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -125,7 +125,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPILookupMissingEvents", "FederationAPILookupMissingEvents",
func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -136,7 +136,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEvent", "FederationAPIGetEvent",
func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
res, err := intAPI.GetEvent(ctx, req.S, req.EventID) res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -147,7 +147,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIGetEventAuth", "FederationAPIGetEventAuth",
func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -174,7 +174,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2836EventRelationships", "FederationAPIMSC2836EventRelationships",
func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),
@ -185,7 +185,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalProxyAPI( httputil.MakeInternalProxyAPI(
"FederationAPIMSC2946SpacesSummary", "FederationAPIMSC2946SpacesSummary",
func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly)
return &res, federationClientError(err) return &res, federationClientError(err)
}, },
), ),

View file

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing map[gomatrixserverlib.ServerName]*SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests

View file

@ -15,7 +15,6 @@
package queue package queue
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync" "sync"
@ -46,7 +45,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client fedapi.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing map[gomatrixserverlib.ServerName]*SigningInfo signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
@ -91,7 +90,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient, client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing map[gomatrixserverlib.ServerName]*SigningInfo, signing []*gomatrixserverlib.SigningIdentity,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
disabled: disabled, disabled: disabled,
@ -101,9 +100,12 @@ func NewOutgoingQueues(
origin: origin, origin: origin,
client: client, client: client,
statistics: statistics, statistics: statistics,
signing: signing, signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{},
queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
} }
for _, identity := range signing {
queues.signing[identity.ServerName] = identity
}
// Look up which servers we have pending items for and then rehydrate those queues. // Look up which servers we have pending items for and then rehydrate those queues.
if !disabled { if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{} serverNames := map[gomatrixserverlib.ServerName]struct{}{}
@ -135,14 +137,6 @@ func NewOutgoingQueues(
return queues return queues
} }
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
// around together
type SigningInfo struct {
ServerName gomatrixserverlib.ServerName
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
}
type queuedPDU struct { type queuedPDU struct {
receipt *shared.Receipt receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent pdu *gomatrixserverlib.HeaderedEvent

View file

@ -350,8 +350,8 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
} }
rs := &stubFederationRoomServerAPI{} rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist) stats := statistics.NewStatistics(db, failuresUntilBlacklist)
signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{ signingInfo := []*gomatrixserverlib.SigningIdentity{
"localhost": { {
KeyID: "ed21019:auto", KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA, PrivateKey: test.PrivateKeyA,
ServerName: "localhost", ServerName: "localhost",

View file

@ -82,7 +82,8 @@ func Backfill(
BackwardsExtremities: map[string][]string{ BackwardsExtremities: map[string][]string{
"": eIDs, "": eIDs,
}, },
ServerName: request.Origin(), ServerName: request.Origin(),
VirtualHost: request.Destination(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
@ -123,7 +124,7 @@ func Backfill(
} }
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: request.Destination(),
PDUs: eventJSONs, PDUs: eventJSONs,
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }

View file

@ -140,6 +140,21 @@ func processInvite(
} }
} }
if event.StateKey() == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The invite event has no state key"),
}
}
_, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)),
}
}
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
@ -175,7 +190,7 @@ func processInvite(
// Sign the event so that other servers will know that we have received the invite. // Sign the event so that other servers will know that we have received the invite.
signedEvent := event.Sign( signedEvent := event.Sign(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
) )
// Add the invite event to the roomserver. // Add the invite event to the roomserver.

View file

@ -131,10 +131,20 @@ func MakeJoin(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
queryRes := api.QueryLatestEventsAndStateResponse{ queryRes := api.QueryLatestEventsAndStateResponse{
RoomVersion: verRes.RoomVersion, RoomVersion: verRes.RoomVersion,
} }
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -16,7 +16,6 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
@ -136,38 +135,52 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, serverName)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.MessageResponse(http.StatusNotFound, err.Error())
} }
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
if !cfg.Matrix.IsLocalServerName(serverName) { var virtualHost *config.VirtualHost
return nil, fmt.Errorf("server name not known") for _, v := range cfg.Matrix.VirtualHosts {
} if v.ServerName != serverName {
continue
keys.ServerName = serverName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: oldVerifyKey.PublicKey,
},
ExpiredTS: oldVerifyKey.ExpiredAt,
} }
virtualHost = v
}
if virtualHost == nil {
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = cfg.Matrix.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{}
for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys {
keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: oldVerifyKey.PublicKey,
},
ExpiredTS: oldVerifyKey.ExpiredAt,
}
}
} else {
publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey)
keys.ServerName = virtualHost.ServerName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod))
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
virtualHost.KeyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
// TODO: Virtual hosts probably want to be able to specify old signing
// keys too, just in case
} }
toSign, err := json.Marshal(keys.ServerKeyFields) toSign, err := json.Marshal(keys.ServerKeyFields)
@ -213,7 +226,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys { for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { if k, err := localKeys(cfg, serverName); err == nil {
keyList = append(keyList, *k) keyList = append(keyList, *k)
} else { } else {
return util.ErrorResponse(err) return util.ErrorResponse(err)

View file

@ -13,6 +13,7 @@
package routing package routing
import ( import (
"fmt"
"net/http" "net/http"
"time" "time"
@ -60,8 +61,18 @@ func MakeLeave(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
identity, err := cfg.Matrix.SigningIdentityFor(request.Destination())
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(
fmt.Sprintf("Server name %q does not exist", request.Destination()),
),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -42,16 +41,9 @@ func GetProfile(
} }
} }
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)),

View file

@ -83,7 +83,7 @@ func RoomAliasToID(
} }
} }
} else { } else {
resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:

View file

@ -197,12 +197,12 @@ type txnReq struct {
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface { type txnFederationClient interface {
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) )
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion), event.Headered(roomVersion),
}, },
t.Destination,
t.Origin, t.Origin,
api.DoNotSendToOtherServers, api.DoNotSendToOtherServers,
nil, nil,

View file

@ -147,7 +147,7 @@ type txnFedClient struct {
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
} }
func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error, res gomatrixserverlib.RespState, err error,
) { ) {
fmt.Println("testFederationClient.LookupState", eventID) fmt.Println("testFederationClient.LookupState", eventID)
@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv
res = r res = r
return return
} }
func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID) fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID] r, ok := c.stateIDs[eventID]
if !ok { if !ok {
@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S
res = r res = r
return return
} }
func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID) fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID] r, ok := c.getEvent[eventID]
if !ok { if !ok {
@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN
res = r res = r
return return
} }
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing) return c.getMissingEvents(missing)
} }

View file

@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites(
} }
// Send all the events // Send all the events
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { if err := api.SendEvents(
req.Context(),
rsAPI,
api.KindNew,
evs,
cfg.Matrix.ServerName, // TODO: which virtual host?
"TODO",
cfg.Matrix.ServerName,
nil,
false,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite(
} }
} }
_, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()),
}
}
// Check that the state key is correct. // Check that the state key is correct.
_, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey)
if err != nil { if err != nil {
@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite(
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite(
[]*gomatrixserverlib.HeaderedEvent{ []*gomatrixserverlib.HeaderedEvent{
inviteEvent.Headered(verRes.RoomVersion), inviteEvent.Headered(verRes.RoomVersion),
}, },
request.Destination(),
request.Origin(), request.Origin(),
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
nil, nil,
@ -341,7 +360,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation federationAPI.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, cfg *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)
@ -357,7 +376,7 @@ func sendToRemoteServer(
} }
for _, server := range remoteServers { for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(ctx, server, builder) err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder)
if err == nil { if err == nil {
return return
} }

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15

4
go.sum
View file

@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf h1:OPM8urHlGC8m1MfusZg6Oi9zxJiDNuSVh3GQuvZeOHw=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4=
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View file

@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist")
// Returns an error if something else went wrong // Returns an error if something else went wrong
func QueryAndBuildEvent( func QueryAndBuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
if queryRes == nil { if queryRes == nil {
@ -50,24 +51,24 @@ func QueryAndBuildEvent(
// This can pass through a ErrRoomNoExists to the caller // This can pass through a ErrRoomNoExists to the caller
return nil, err return nil, err
} }
return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes)
} }
// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse
// provided. // provided.
func BuildEvent( func BuildEvent(
ctx context.Context, ctx context.Context,
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, builder *gomatrixserverlib.EventBuilder, cfg *config.Global,
identity *gomatrixserverlib.SigningIdentity, evTime time.Time,
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil {
if err != nil {
return nil, err return nil, err
} }
event, err := builder.Build( event, err := builder.Build(
evTime, cfg.ServerName, cfg.KeyID, evTime, identity.ServerName, identity.KeyID,
cfg.PrivateKey, queryRes.RoomVersion, identity.PrivateKey, queryRes.RoomVersion,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -30,12 +30,12 @@ import (
// DeviceListUpdateConsumer consumes device list updates that came in over federation. // DeviceListUpdateConsumer consumes device list updates that came in over federation.
type DeviceListUpdateConsumer struct { type DeviceListUpdateConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
topic string topic string
updater *internal.DeviceListUpdater updater *internal.DeviceListUpdater
serverName gomatrixserverlib.ServerName isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer(
updater *internal.DeviceListUpdater, updater *internal.DeviceListUpdater,
) *DeviceListUpdateConsumer { ) *DeviceListUpdateConsumer {
return &DeviceListUpdateConsumer{ return &DeviceListUpdateConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
updater: updater, updater: updater,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
return true return true
} else if serverName == t.serverName { } else if t.isLocalServerName(serverName) {
return true return true
} else if serverName != origin { } else if serverName != origin {
return true return true

View file

@ -31,12 +31,13 @@ import (
// SigningKeyUpdateConsumer consumes signing key updates that came in over federation. // SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
type SigningKeyUpdateConsumer struct { type SigningKeyUpdateConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
topic string topic string
keyAPI *internal.KeyInternalAPI keyAPI *internal.KeyInternalAPI
cfg *config.KeyServer cfg *config.KeyServer
isLocalServerName func(gomatrixserverlib.ServerName) bool
} }
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer(
keyAPI *internal.KeyInternalAPI, keyAPI *internal.KeyInternalAPI,
) *SigningKeyUpdateConsumer { ) *SigningKeyUpdateConsumer {
return &SigningKeyUpdateConsumer{ return &SigningKeyUpdateConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
keyAPI: keyAPI, keyAPI: keyAPI,
cfg: cfg, cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
} }
} }
@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
logrus.WithError(err).Error("failed to split user id") logrus.WithError(err).Error("failed to split user id")
return true return true
} else if serverName == t.cfg.Matrix.ServerName { } else if t.isLocalServerName(serverName) {
logrus.Warn("dropping device key update from ourself") logrus.Warn("dropping device key update from ourself")
return true return true
} else if serverName != origin { } else if serverName != origin {

View file

@ -96,6 +96,7 @@ type DeviceListUpdater struct {
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.KeyserverFederationAPI fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
thisServer gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select. // block on or timeout via a select.
@ -139,6 +140,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process, process: process,
@ -148,6 +150,7 @@ func NewDeviceListUpdater(
api: api, api: api,
producer: producer, producer: producer,
fedClient: fedClient, fedClient: fedClient,
thisServer: thisServer,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool), userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{}, userIDToChanMu: &sync.Mutex{},
@ -435,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
"server_name": serverName, "server_name": serverName,
"user_id": userID, "user_id": userID,
}) })
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err return time.Minute * 10, err

View file

@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil) _, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient( fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, []*gomatrixserverlib.SigningIdentity{
{
ServerName: gomatrixserverlib.ServerName("example.test"),
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
PrivateKey: pkey,
},
},
) )
fedClient.Client = *gomatrixserverlib.NewClient( fedClient.Client = *gomatrixserverlib.NewClient(
gomatrixserverlib.WithTransport(&roundTripper{tripper}), gomatrixserverlib.WithTransport(&roundTripper{tripper}),
@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -146,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
defer wg.Done() defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -559,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 { if len(devKeys) == 0 {
return return
} }
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil { if err == nil {
resultCh <- &queryKeysResp resultCh <- &queryKeysResp
return return

View file

@ -58,7 +58,7 @@ func NewInternalAPI(
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {

View file

@ -94,8 +94,9 @@ type TransactionID struct {
// InputRoomEventsRequest is a request to InputRoomEvents // InputRoomEventsRequest is a request to InputRoomEvents
type InputRoomEventsRequest struct { type InputRoomEventsRequest struct {
InputRoomEvents []InputRoomEvent `json:"input_room_events"` InputRoomEvents []InputRoomEvent `json:"input_room_events"`
Asynchronous bool `json:"async"` Asynchronous bool `json:"async"`
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// InputRoomEventsResponse is a response to InputRoomEvents // InputRoomEventsResponse is a response to InputRoomEvents

View file

@ -148,6 +148,8 @@ type PerformBackfillRequest struct {
Limit int `json:"limit"` Limit int `json:"limit"`
// The server interested in the events. // The server interested in the events.
ServerName gomatrixserverlib.ServerName `json:"server_name"` ServerName gomatrixserverlib.ServerName `json:"server_name"`
// Which virtual host are we doing this for?
VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"`
} }
// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order.

View file

@ -26,7 +26,7 @@ import (
func SendEvents( func SendEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
kind Kind, events []*gomatrixserverlib.HeaderedEvent, kind Kind, events []*gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, virtualHost, origin gomatrixserverlib.ServerName,
sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID,
async bool, async bool,
) error { ) error {
@ -40,14 +40,15 @@ func SendEvents(
TransactionID: txnID, TransactionID: txnID,
} }
} }
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendEventWithState writes an event with the specified kind to the roomserver // SendEventWithState writes an event with the specified kind to the roomserver
// with the state at the event as KindOutlier before it. Will not send any event that is // with the state at the event as KindOutlier before it. Will not send any event that is
// marked as `true` in haveEventIDs. // marked as `true` in haveEventIDs.
func SendEventWithState( func SendEventWithState(
ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName, kind Kind,
state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent,
origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool,
) error { ) error {
@ -85,17 +86,19 @@ func SendEventWithState(
StateEventIDs: stateEventIDs, StateEventIDs: stateEventIDs,
}) })
return SendInputRoomEvents(ctx, rsAPI, ires, async) return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async)
} }
// SendInputRoomEvents to the roomserver. // SendInputRoomEvents to the roomserver.
func SendInputRoomEvents( func SendInputRoomEvents(
ctx context.Context, rsAPI InputRoomEventsAPI, ctx context.Context, rsAPI InputRoomEventsAPI,
virtualHost gomatrixserverlib.ServerName,
ires []InputRoomEvent, async bool, ires []InputRoomEvent, async bool,
) error { ) error {
request := InputRoomEventsRequest{ request := InputRoomEventsRequest{
InputRoomEvents: ires, InputRoomEvents: ires,
Asynchronous: async, Asynchronous: async,
VirtualHost: virtualHost,
} }
var response InputRoomEventsResponse var response InputRoomEventsResponse
if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {

View file

@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest, request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
_, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID)
if err != nil {
return err
}
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
sender = ev.Sender() sender = ev.Sender()
} }
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender)
if err != nil {
return err
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
Sender: sender, Sender: sender,
RoomID: ev.RoomID(), RoomID: ev.RoomID(),
@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err return err
} }
newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes)
if err != nil { if err != nil {
return err return err
} }
err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil { if err != nil {
return err return err
} }

View file

@ -87,10 +87,10 @@ func NewRoomserverAPI(
Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"),
ServerACLs: serverACLs, ServerACLs: serverACLs,
Queryer: &query.Queryer{ Queryer: &query.Queryer{
DB: roomserverDB, DB: roomserverDB,
Cache: base.Caches, Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName, IsLocalServerName: base.Cfg.Global.IsLocalServerName,
ServerACLs: serverACLs, ServerACLs: serverACLs,
}, },
// perform-er structs get initialised when we have a federation sender to use // perform-er structs get initialised when we have a federation sender to use
} }
@ -127,13 +127,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Joiner = &perform.Joiner{ r.Joiner = &perform.Joiner{
ServerName: r.Cfg.Matrix.ServerName, Cfg: r.Cfg,
Cfg: r.Cfg, DB: r.DB,
DB: r.DB, FSAPI: r.fsAPI,
FSAPI: r.fsAPI, RSAPI: r,
RSAPI: r, Inputer: r.Inputer,
Inputer: r.Inputer, Queryer: r.Queryer,
Queryer: r.Queryer,
} }
r.Peeker = &perform.Peeker{ r.Peeker = &perform.Peeker{
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.Cfg.Matrix.ServerName,
@ -163,10 +162,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
DB: r.DB, DB: r.DB,
} }
r.Backfiller = &perform.Backfiller{ r.Backfiller = &perform.Backfiller{
ServerName: r.ServerName, IsLocalServerName: r.Cfg.Matrix.IsLocalServerName,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
KeyRing: r.KeyRing, KeyRing: r.KeyRing,
// Perspective servers are trusted to not lie about server keys, so we will also // Perspective servers are trusted to not lie about server keys, so we will also
// prefer these servers when backfilling (assuming they are in the room) rather // prefer these servers when backfilling (assuming they are in the room) rather
// than trying random servers // than trying random servers

View file

@ -278,7 +278,11 @@ func (w *worker) _next() {
// a string, because we might want to return that to the caller if // a string, because we might want to return that to the caller if
// it was a synchronous request. // it was a synchronous request.
var errString string var errString string
if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { if err = w.r.processRoomEvent(
w.r.ProcessContext.Context(),
gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")),
&inputRoomEvent,
); err != nil {
switch err.(type) { switch err.(type) {
case types.RejectedError: case types.RejectedError:
// Don't send events that were rejected to Sentry // Don't send events that were rejected to Sentry
@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents(
if replyTo != "" { if replyTo != "" {
msg.Header.Set("sync", replyTo) msg.Header.Set("sync", replyTo)
} }
msg.Header.Set("virtual_host", string(request.VirtualHost))
msg.Data, err = json.Marshal(e) msg.Data, err = json.Marshal(e)
if err != nil { if err != nil {
return nil, fmt.Errorf("json.Marshal: %w", err) return nil, fmt.Errorf("json.Marshal: %w", err)

View file

@ -69,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
virtualHost gomatrixserverlib.ServerName,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) error { ) error {
select { select {
@ -200,7 +201,7 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
@ -265,16 +266,17 @@ func (r *Inputer) processRoomEvent(
// processRoomEvent. // processRoomEvent.
if len(serverRes.ServerNames) > 0 { if len(serverRes.ServerNames) > 0 {
missingState := missingStateReq{ missingState := missingStateReq{
origin: input.Origin, origin: input.Origin,
inputer: r, virtualHost: virtualHost,
db: r.DB, inputer: r,
roomInfo: roomInfo, db: r.DB,
federation: r.FSAPI, roomInfo: roomInfo,
keys: r.KeyRing, federation: r.FSAPI,
roomsMu: internal.NewMutexByRoom(), keys: r.KeyRing,
servers: serverRes.ServerNames, roomsMu: internal.NewMutexByRoom(),
hadEvents: map[string]bool{}, servers: serverRes.ServerNames,
haveEvents: map[string]*gomatrixserverlib.Event{}, hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.Event{},
} }
var stateSnapshot *parsedRespState var stateSnapshot *parsedRespState
if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
@ -555,6 +557,7 @@ func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
logger *logrus.Entry, logger *logrus.Entry,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
virtualHost gomatrixserverlib.ServerName,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
known map[string]*types.Event, known map[string]*types.Event,
@ -605,7 +608,7 @@ func (r *Inputer) fetchAuthEvents(
// Request the entire auth chain for the event in question. This should // Request the entire auth chain for the event in question. This should
// contain all of the auth events — including ones that we already know — // contain all of the auth events — including ones that we already know —
// so we'll need to filter through those in the next section. // so we'll need to filter through those in the next section.
res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID())
if err != nil { if err != nil {
logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err)
continue continue

View file

@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event {
type missingStateReq struct { type missingStateReq struct {
log *logrus.Entry log *logrus.Entry
virtualHost gomatrixserverlib.ServerName
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db storage.Database
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState(
}) })
} }
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil {
if _, ok := err.(types.RejectedError); !ok { if _, ok := err.(types.RejectedError); !ok {
return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err)
} }
@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
for _, server := range t.servers { for _, server := range t.servers {
var m gomatrixserverlib.RespMissingEvents var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{
Limit: 20, Limit: 20,
// The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events.
EarliestEvents: latestEvents, EarliestEvents: latestEvents,
@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState(
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState")
defer span.Finish() defer span.Finish()
state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5)
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20)
stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID)
reqcancel() reqcancel()
if err == nil { if err == nil {
break break
@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30) reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {
t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID")
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {

View file

@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil return nil
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
continue
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState(
req *api.PerformAdminDownloadStateRequest, req *api.PerformAdminDownloadStateRequest,
res *api.PerformAdminDownloadStateResponse, res *api.PerformAdminDownloadStateResponse,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err),
}
return nil
}
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState(
for _, fwdExtremity := range fwdExtremities { for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.RespState var state gomatrixserverlib.RespState
state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState(
Depth: depth, Depth: depth,
} }
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return err
}
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,

View file

@ -37,10 +37,10 @@ import (
const maxBackfillServers = 5 const maxBackfillServers = 5
type Backfiller struct { type Backfiller struct {
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
DB storage.Database DB storage.Database
FSAPI federationAPI.RoomserverFederationAPI FSAPI federationAPI.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
// The servers which should be preferred above other servers when backfilling // The servers which should be preferred above other servers when backfilling
PreferServers []gomatrixserverlib.ServerName PreferServers []gomatrixserverlib.ServerName
@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill(
// if we are requesting the backfill then we need to do a federation hit // if we are requesting the backfill then we need to do a federation hit
// TODO: we could be more sensible and fetch as many events we already have then request the rest // TODO: we could be more sensible and fetch as many events we already have then request the rest
// which is what the syncapi does already. // which is what the syncapi does already.
if request.ServerName == r.ServerName { if r.IsLocalServerName(request.ServerName) {
return r.backfillViaFederation(ctx, request, response) return r.backfillViaFederation(ctx, request, response)
} }
// someone else is requesting the backfill, try to service their request. // someone else is requesting the backfill, try to service their request.
@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
} }
requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
// Request 100 items regardless of what the query asks for. // Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this. // We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
// (so we don't need to hit /state_ids which the test has no listener for) // (so we don't need to hit /state_ids which the test has no listener for)
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
return err return err
} }
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil {
// attempt to fetch the missing events // attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost)
// try again // try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true)
if err != nil { if err != nil {
@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
// best effort. // best effort.
func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
backfillRequester *backfillRequester, stateIDs []string) { backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) {
servers := backfillRequester.servers servers := backfillRequester.servers
@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue // already found continue // already found
} }
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
res, err := r.FSAPI.GetEvent(ctx, srv, id) res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id)
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to get event from server") logger.WithError(err).Warn("failed to get event from server")
continue continue
@ -241,11 +243,12 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
// backfillRequester implements gomatrixserverlib.BackfillRequester // backfillRequester implements gomatrixserverlib.BackfillRequester
type backfillRequester struct { type backfillRequester struct {
db storage.Database db storage.Database
fsAPI federationAPI.RoomserverFederationAPI fsAPI federationAPI.RoomserverFederationAPI
thisServer gomatrixserverlib.ServerName virtualHost gomatrixserverlib.ServerName
preferServer map[gomatrixserverlib.ServerName]bool isLocalServerName func(gomatrixserverlib.ServerName) bool
bwExtrems map[string][]string preferServer map[gomatrixserverlib.ServerName]bool
bwExtrems map[string][]string
// per-request state // per-request state
servers []gomatrixserverlib.ServerName servers []gomatrixserverlib.ServerName
@ -255,7 +258,9 @@ type backfillRequester struct {
} }
func newBackfillRequester( func newBackfillRequester(
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
virtualHost gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName,
) *backfillRequester { ) *backfillRequester {
preferServer := make(map[gomatrixserverlib.ServerName]bool) preferServer := make(map[gomatrixserverlib.ServerName]bool)
@ -265,7 +270,8 @@ func newBackfillRequester(
return &backfillRequester{ return &backfillRequester{
db: db, db: db,
fsAPI: fsAPI, fsAPI: fsAPI,
thisServer: thisServer, virtualHost: virtualHost,
isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]*gomatrixserverlib.Event), eventIDMap: make(map[string]*gomatrixserverlib.Event),
bwExtrems: bwExtrems, bwExtrems: bwExtrems,
@ -450,7 +456,7 @@ FindSuccessor:
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost)
b.historyVisiblity = visibility b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@ -477,7 +483,7 @@ FindSuccessor:
} }
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {
if server == b.thisServer { if b.isLocalServerName(server) {
continue continue
} }
if b.preferServer[server] { // insert at the front if b.preferServer[server] { // insert at the front
@ -496,10 +502,10 @@ FindSuccessor:
// Backfill performs a backfill request to the given server. // Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string,
limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) {
tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs)
return tx, err return tx, err
} }

View file

@ -39,11 +39,10 @@ import (
) )
type Joiner struct { type Joiner struct {
ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer
Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI
FSAPI fsAPI.RoomserverFederationAPI RSAPI rsAPI.RoomserverInternalAPI
RSAPI rsAPI.RoomserverInternalAPI DB storage.Database
DB storage.Database
Inputer *input.Inputer Inputer *input.Inputer
Queryer *query.Queryer Queryer *query.Queryer
@ -197,7 +196,7 @@ func (r *Joiner) performJoinRoomByID(
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorBadRequest, Code: rsAPI.PerformErrorBadRequest,
@ -283,7 +282,7 @@ func (r *Joiner) performJoinRoomByID(
// locally on the homeserver. // locally on the homeserver.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb)
switch err { switch err {
case nil: case nil:
@ -410,7 +409,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
} }
func buildEvent( func buildEvent(
ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, ctx context.Context, db storage.Database, cfg *config.Global,
senderDomain gomatrixserverlib.ServerName,
builder *gomatrixserverlib.EventBuilder,
) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
@ -438,7 +439,12 @@ func buildEvent(
} }
} }
ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) identity, err := cfg.SigningIdentityFor(senderDomain)
if err != nil {
return nil, nil, err
}
ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -162,21 +162,21 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eb.SetUnsigned: %w", err) return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender)
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", eb.Sender)
}
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
// a leave event. // a leave event.
// TODO: Check what happens if the room exists on the server // TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing. // but everyone has since left. I suspect it does the wrong thing.
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb)
if err != nil { if err != nil {
return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
} }
// Get the sender domain.
_, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender())
if serr != nil {
return nil, fmt.Errorf("sender %q is invalid", event.Sender())
}
// Give our leave event to the roomserver input stream. The // Give our leave event to the roomserver input stream. The
// roomserver will process the membership change and notify // roomserver will process the membership change and notify
// downstream automatically. // downstream automatically.

View file

@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade(
) (string, *api.PerformError) { ) (string, *api.PerformError) {
roomID := req.RoomID roomID := req.RoomID
userID := req.UserID userID := req.UserID
_, userDomain, err := gomatrixserverlib.SplitID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", &api.PerformError{
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
SendAsServer: api.DoNotSendToOtherServers, SendAsServer: api.DoNotSendToOtherServers,
}) })
} }
if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err),
} }
@ -595,8 +595,21 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err),
} }
} }
// Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender)
if serr != nil {
return nil, &api.PerformError{
Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err),
}
}
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return nil, &api.PerformError{
Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err),
}
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, evTime, r.URSAPI, &queryRes) headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &api.PerformError{ return nil, &api.PerformError{
Code: api.PerformErrorNoRoom, Code: api.PerformErrorNoRoom,
@ -686,7 +699,7 @@ func (r *Upgrader) sendHeaderedEvent(
Origin: serverName, Origin: serverName,
SendAsServer: sendAsServer, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil {
return &api.PerformError{ return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err),
} }

View file

@ -37,10 +37,10 @@ import (
) )
type Queryer struct { type Queryer struct {
DB storage.Database DB storage.Database
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
ServerName gomatrixserverlib.ServerName IsLocalServerName func(gomatrixserverlib.ServerName) bool
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
} }
// QueryLatestEventsAndState implements api.RoomserverInternalAPI // QueryLatestEventsAndState implements api.RoomserverInternalAPI
@ -392,7 +392,7 @@ func (r *Queryer) QueryServerJoinedToRoom(
} }
response.RoomExists = true response.RoomExists = true
if request.ServerName == r.ServerName || request.ServerName == "" { if r.IsLocalServerName(request.ServerName) || request.ServerName == "" {
response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)

View file

@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) {
// SetFederationAPI starts the room event input consumer // SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil) rsAPI.SetFederationAPI(nil, nil)
// Create the room // Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }

View file

@ -364,10 +364,10 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client {
// CreateFederationClient creates a new federation client. Should only be called // CreateFederationClient creates a new federation client. Should only be called
// once per component. // once per component.
func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient {
identities := b.Cfg.Global.SigningIdentities()
if b.Cfg.Global.DisableFederation { if b.Cfg.Global.DisableFederation {
return gomatrixserverlib.NewFederationClient( return gomatrixserverlib.NewFederationClient(
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, identities, gomatrixserverlib.WithTransport(noOpHTTPTransport),
gomatrixserverlib.WithTransport(noOpHTTPTransport),
) )
} }
opts := []gomatrixserverlib.ClientOption{ opts := []gomatrixserverlib.ClientOption{
@ -379,8 +379,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli
opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache))
} }
client := gomatrixserverlib.NewFederationClient( client := gomatrixserverlib.NewFederationClient(
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, identities, opts...,
b.Cfg.Global.PrivateKey, opts...,
) )
client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString()))
return client return client

View file

@ -231,6 +231,21 @@ func loadConfig(
return nil, err return nil, err
} }
for _, v := range c.Global.VirtualHosts {
if v.KeyValidityPeriod == 0 {
v.KeyValidityPeriod = c.Global.KeyValidityPeriod
}
if v.PrivateKeyPath == "" {
v.KeyID = c.Global.KeyID
v.PrivateKey = c.Global.PrivateKey
continue
}
privateKeyPath := absPath(basePath, v.PrivateKeyPath)
if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil {
return nil, err
}
}
for _, key := range c.Global.OldVerifyKeys { for _, key := range c.Global.OldVerifyKeys {
switch { switch {
case key.PrivateKeyPath != "": case key.PrivateKeyPath != "":

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"fmt"
"math/rand" "math/rand"
"strconv" "strconv"
"strings" "strings"
@ -15,7 +16,7 @@ type Global struct {
ServerName gomatrixserverlib.ServerName `yaml:"server_name"` ServerName gomatrixserverlib.ServerName `yaml:"server_name"`
// The secondary server names, used for virtual hosting. // The secondary server names, used for virtual hosting.
SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` VirtualHosts []*VirtualHost `yaml:"virtual_hosts"`
// Path to the private key which will be used to sign requests and events. // Path to the private key which will be used to sign requests and events.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`
@ -114,6 +115,10 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.server_name", string(c.ServerName))
checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath))
for _, v := range c.VirtualHosts {
v.Verify(configErrs)
}
c.JetStream.Verify(configErrs, isMonolith) c.JetStream.Verify(configErrs, isMonolith)
c.Metrics.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith)
c.Sentry.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith)
@ -127,14 +132,85 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool
if c.ServerName == serverName { if c.ServerName == serverName {
return true return true
} }
for _, secondaryName := range c.SecondaryServerNames { for _, v := range c.VirtualHosts {
if secondaryName == serverName { if v.ServerName == serverName {
return true return true
} }
} }
return false return false
} }
func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) {
u, s, err := gomatrixserverlib.SplitID(sigil, id)
if err != nil {
return u, s, err
}
if !c.IsLocalServerName(s) {
return u, s, fmt.Errorf("server name %q not known", s)
}
return u, s, nil
}
func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) {
for _, id := range c.SigningIdentities() {
if id.ServerName == serverName {
return id, nil
}
}
return nil, fmt.Errorf("no signing identity %q", serverName)
}
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1)
identities = append(identities, &gomatrixserverlib.SigningIdentity{
ServerName: c.ServerName,
KeyID: c.KeyID,
PrivateKey: c.PrivateKey,
})
for _, v := range c.VirtualHosts {
identities = append(identities, v.SigningIdentity())
}
return identities
}
type VirtualHost struct {
// The server name of the virtual host.
ServerName gomatrixserverlib.ServerName `yaml:"server_name"`
// The key ID of the private key. If not specified, the default global key ID
// will be used instead.
KeyID gomatrixserverlib.KeyID `yaml:"key_id"`
// Path to the private key. If not specified, the default global private key
// will be used instead.
PrivateKeyPath Path `yaml:"private_key"`
// The private key itself.
PrivateKey ed25519.PrivateKey `yaml:"-"`
// How long a remote server can cache our server key for before requesting it again.
// Increasing this number will reduce the number of requests made by remote servers
// for our key, but increases the period a compromised key will be considered valid
// by remote servers.
// Defaults to 24 hours.
KeyValidityPeriod time.Duration `yaml:"key_validity_period"`
// Is registration enabled on this virtual host?
AllowRegistration bool `json:"allow_registration"`
}
func (v *VirtualHost) Verify(configErrs *ConfigErrors) {
checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName))
}
func (v *VirtualHost) SigningIdentity() *gomatrixserverlib.SigningIdentity {
return &gomatrixserverlib.SigningIdentity{
ServerName: v.ServerName,
KeyID: v.KeyID,
PrivateKey: v.PrivateKey,
}
}
type OldVerifyKeys struct { type OldVerifyKeys struct {
// Path to the private key. // Path to the private key.
PrivateKeyPath Path `yaml:"private_key"` PrivateKeyPath Path `yaml:"private_key"`

View file

@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
serversToQuery := rc.getServersForEventID(parentID) serversToQuery := rc.getServersForEventID(parentID)
var result *MSC2836EventRelationshipsResponse var result *MSC2836EventRelationshipsResponse
for _, srv := range serversToQuery { for _, srv := range serversToQuery {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: parentID, EventID: parentID,
Direction: "down", Direction: "down",
Limit: 100, Limit: 100,
@ -484,7 +484,7 @@ func walkThread(
// MSC2836EventRelationships performs an /event_relationships request to a remote server // MSC2836EventRelationships performs an /event_relationships request to a remote server
func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: eventID, EventID: eventID,
DepthFirst: rc.req.DepthFirst, DepthFirst: rc.req.DepthFirst,
Direction: rc.req.Direction, Direction: rc.req.Direction,
@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo
}) })
} }
// we've got the data by this point so use a background context // we've got the data by this point so use a background context
err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false)
if err != nil { if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
} }

View file

@ -433,7 +433,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver
if serverName == string(w.thisServer) { if serverName == string(w.thisServer) {
continue continue
} }
res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName)
continue continue

View file

@ -28,22 +28,20 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// OutputKeyChangeEventConsumer consumes events that originated in the key server. // OutputKeyChangeEventConsumer consumes events that originated in the key server.
type OutputKeyChangeEventConsumer struct { type OutputKeyChangeEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
topic string topic string
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
stream streams.StreamProvider stream streams.StreamProvider
serverName gomatrixserverlib.ServerName // our server name rsAPI roomserverAPI.SyncRoomserverAPI
rsAPI roomserverAPI.SyncRoomserverAPI
} }
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@ -59,15 +57,14 @@ func NewOutputKeyChangeEventConsumer(
stream streams.StreamProvider, stream streams.StreamProvider,
) *OutputKeyChangeEventConsumer { ) *OutputKeyChangeEventConsumer {
s := &OutputKeyChangeEventConsumer{ s := &OutputKeyChangeEventConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"),
topic: topic, topic: topic,
db: store, db: store,
serverName: cfg.Matrix.ServerName, rsAPI: rsAPI,
rsAPI: rsAPI, notifier: notifier,
notifier: notifier, stream: stream,
stream: stream,
} }
return s return s

View file

@ -34,14 +34,13 @@ import (
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
type OutputReceiptEventConsumer struct { type OutputReceiptEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
topic string topic string
db storage.Database db storage.Database
stream streams.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -55,14 +54,13 @@ func NewOutputReceiptEventConsumer(
stream streams.StreamProvider, stream streams.StreamProvider,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"),
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName,
} }
} }

View file

@ -364,11 +364,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content) // TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join {
// check it's a local join // check it's a local join
_, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
if err != nil {
return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if domain != s.cfg.Matrix.ServerName {
return sp, nil return sp, nil
} }
@ -390,9 +386,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
if msg.Event.StateKey() == nil { if msg.Event.StateKey() == nil {
return return
} }
if _, serverName, err := gomatrixserverlib.SplitID('@', *msg.Event.StateKey()); err != nil { if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
return
} else if serverName != s.cfg.Matrix.ServerName {
return return
} }
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)

View file

@ -37,15 +37,15 @@ import (
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
type OutputSendToDeviceEventConsumer struct { type OutputSendToDeviceEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
topic string topic string
db storage.Database db storage.Database
keyAPI keyapi.SyncKeyAPI keyAPI keyapi.SyncKeyAPI
serverName gomatrixserverlib.ServerName // our server name isLocalServerName func(gomatrixserverlib.ServerName) bool
stream streams.StreamProvider stream streams.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
} }
// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer.
@ -60,15 +60,15 @@ func NewOutputSendToDeviceEventConsumer(
stream streams.StreamProvider, stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer { ) *OutputSendToDeviceEventConsumer {
return &OutputSendToDeviceEventConsumer{ return &OutputSendToDeviceEventConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"),
db: store, db: store,
keyAPI: keyAPI, keyAPI: keyAPI,
serverName: cfg.Matrix.ServerName, isLocalServerName: cfg.Matrix.IsLocalServerName,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
} }
} }
@ -89,7 +89,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []
log.WithError(err).Errorf("send-to-device: failed to split user id, dropping message") log.WithError(err).Errorf("send-to-device: failed to split user id, dropping message")
return true return true
} }
if domain != s.serverName { if !s.isLocalServerName(domain) {
log.Tracef("ignoring send-to-device event with destination %s", domain) log.Tracef("ignoring send-to-device event with destination %s", domain)
return true return true
} }
@ -114,7 +114,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []
if output.Type == "m.room_key_request" { if output.Type == "m.room_key_request" {
requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str
_, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender)
if requestingDeviceID != "" && senderDomain != s.serverName { if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) {
// Mark the requesting device as stale, if we don't know about it. // Mark the requesting device as stale, if we don't know about it.
if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{
UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID,

View file

@ -528,6 +528,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
BackwardsExtremities: backwardsExtremities, BackwardsExtremities: backwardsExtremities,
Limit: limit, Limit: limit,
ServerName: r.cfg.Matrix.ServerName, ServerName: r.cfg.Matrix.ServerName,
VirtualHost: r.device.UserDomain(),
}, &res) }, &res)
if err != nil { if err != nil {
return nil, fmt.Errorf("PerformBackfill failed: %w", err) return nil, fmt.Errorf("PerformBackfill failed: %w", err)

View file

@ -433,7 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)
beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody})
eventsToSend := append(room.Events(), beforeJoinEv) eventsToSend := append(room.Events(), beforeJoinEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }
syncUntil(t, base, aliceDev.AccessToken, false, syncUntil(t, base, aliceDev.AccessToken, false,
@ -472,7 +472,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) t.Fatalf("failed to send events: %v", err)
} }
syncUntil(t, base, aliceDev.AccessToken, false, syncUntil(t, base, aliceDev.AccessToken, false,