diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index b05a931fe..f39b26eaf 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -93,6 +93,7 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen FSAPI: r.fsAPI, RSAPI: r, Inputer: r.Inputer, + Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ ServerName: r.Cfg.Matrix.ServerName, diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a829bffca..a389cc898 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -50,6 +50,10 @@ func UpdateToInviteMembership( return updates, nil } +// IsServerCurrentlyInRoom checks if a server is in a given room, based on the room +// memberships. If the servername is not supplied then the local server will be +// checked instead using a faster code path. +// TODO: This should probably be replaced by an API call. func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { @@ -59,6 +63,10 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, fmt.Errorf("unknown room %s", roomID) } + if serverName == "" { + return db.GetLocalServerInRoom(ctx, info.RoomNID) + } + eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) if err != nil { return false, err diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 048496d45..876888e29 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -28,6 +28,7 @@ import ( rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -42,6 +43,7 @@ type Joiner struct { DB storage.Database Inputer *input.Inputer + Queryer *query.Queryer } // PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender. @@ -205,7 +207,14 @@ func (r *Joiner) performJoinRoomByID( // Force a federated join if we aren't in the room and we've been // given some server names to try joining by. - serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) + inRoomReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: req.RoomIDOrAlias, + } + inRoomRes := &api.QueryServerJoinedToRoomResponse{} + if err = r.Queryer.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil { + return "", "", fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err) + } + serverInRoom := inRoomRes.IsInRoom forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom // Force a federated join if we're dealing with a pending invite diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ccd093726..4af0e6397 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -388,10 +388,16 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( return } roomID := events[0].RoomID() - isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID) - if err != nil { - return + + inRoomReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: roomID, + ServerName: request.ServerName, } + inRoomRes := &api.QueryServerJoinedToRoomResponse{} + if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil { + return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err) + } + info, err := r.DB.RoomInfo(ctx, roomID) if err != nil { return err @@ -400,7 +406,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom, + ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, ) return }