diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 24958091b..72a2cf1eb 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -70,7 +70,7 @@ func CheckForSoftFail( ) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -88,7 +88,7 @@ func CheckForSoftFail( func CheckAuthEvents( ctx context.Context, db storage.RoomDatabase, - roomInfo *types.RoomInfo, + roomVersion gomatrixserverlib.RoomVersion, event *types.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -103,7 +103,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -196,7 +196,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) g func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, - roomInfo *types.RoomInfo, + roomVersion gomatrixserverlib.RoomVersion, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -220,11 +220,7 @@ func loadAuthEvents( } } - if roomInfo == nil { - err = types.ErrorInvalidRoomInfo - return - } - if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index c69a4b8b9..9f7e379eb 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -142,9 +142,28 @@ func (r *Inviter) PerformInvite( if err != nil { return nil, err } - info, inviteState, err := r.generateInviteStrippedState(ctx, *validRoomID, req.Event, req.InviteRoomState) - if err != nil { - return nil, err + + inviteState := req.InviteRoomState + if len(inviteState) == 0 { + // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." + // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member + stateWanted := []gomatrixserverlib.StateKeyTuple{} + for _, t := range []string{ + spec.MRoomName, spec.MRoomCanonicalAlias, + spec.MRoomJoinRules, spec.MRoomAvatar, + spec.MRoomEncryption, spec.MRoomCreate, + } { + stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ + EventType: t, + StateKey: "", + }) + } + if is, err := r.GenerateInviteStrippedState(ctx, *validRoomID, stateWanted, req.Event); err == nil { + inviteState = is + } else { + util.GetLogger(ctx).WithError(err).Error("failed querying known room") + return nil, spec.InternalServerError{} + } } logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ @@ -154,10 +173,9 @@ func (r *Inviter) PerformInvite( "event_id": event.EventID(), }) logger.WithFields(log.Fields{ - "room_version": req.RoomVersion, - "room_info_exists": info != nil, - "target_local": isTargetLocal, - "origin_local": isOriginLocal, + "room_version": req.RoomVersion, + "target_local": isTargetLocal, + "origin_local": true, }).Debug("processing invite event") if len(inviteState) == 0 { @@ -170,14 +188,17 @@ func (r *Inviter) PerformInvite( } } - var isAlreadyJoined bool - if info != nil { - _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) - if err != nil { - return nil, fmt.Errorf("r.DB.GetMembership: %w", err) - } + membershipReq := api.QueryMembershipForUserRequest{ + RoomID: validRoomID.String(), + UserID: *event.StateKey(), } - if isAlreadyJoined { + res := api.QueryMembershipForUserResponse{} + err = r.RSAPI.QueryMembershipForUser(ctx, &membershipReq, &res) + + if err != nil { + return nil, fmt.Errorf("r.RSAPI.QueryMembershipForUser: %w", err) + } + if res.Membership == spec.Join { // If the user is joined to the room then that takes precedence over this // invite event. It makes little sense to move a user that is already // joined to the room into the invite state. @@ -213,7 +234,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, req.RoomVersion, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "ProcessInviteEvent.checkAuthEvents failed for event", @@ -266,68 +287,3 @@ func (r *Inviter) PerformInvite( // gets the invite, as the roomserver will do this when it processes the m.room.member invite. return outputUpdates, nil } - -func (r *Inviter) generateInviteStrippedState( - ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []gomatrixserverlib.InviteStrippedState, -) (*types.RoomInfo, []gomatrixserverlib.InviteStrippedState, error) { - info, err := r.DB.RoomInfo(ctx, roomID.String()) - if err != nil { - return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err) - } - strippedState := inviteState - if len(strippedState) == 0 && info != nil { - var is []gomatrixserverlib.InviteStrippedState - if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil { - strippedState = is - } - } - - return info, strippedState, nil -} - -func buildInviteStrippedState( - ctx context.Context, - db storage.Database, - info *types.RoomInfo, - inviteEvent *types.HeaderedEvent, -) ([]gomatrixserverlib.InviteStrippedState, error) { - stateWanted := []gomatrixserverlib.StateKeyTuple{} - // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." - // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member - for _, t := range []string{ - spec.MRoomName, spec.MRoomCanonicalAlias, - spec.MRoomJoinRules, spec.MRoomAvatar, - spec.MRoomEncryption, spec.MRoomCreate, - } { - stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ - EventType: t, - StateKey: "", - }) - } - roomState := state.NewStateResolution(db, info) - stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( - ctx, info.StateSnapshotNID(), stateWanted, - ) - if err != nil { - return nil, err - } - stateNIDs := []types.EventNID{} - for _, stateNID := range stateEntries { - stateNIDs = append(stateNIDs, stateNID.EventNID) - } - if info == nil { - return nil, types.ErrorInvalidRoomInfo - } - stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs) - if err != nil { - return nil, err - } - inviteState := []gomatrixserverlib.InviteStrippedState{ - gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU), - } - stateEvents = append(stateEvents, types.Event{PDU: inviteEvent.PDU}) - for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteStrippedState(event.PDU)) - } - return inviteState, nil -}