Refactor PerformInvite to remove duplication

This commit is contained in:
Devon Hudson 2023-05-23 15:16:28 -06:00
parent b426961d84
commit e678912a80
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
2 changed files with 41 additions and 89 deletions

View file

@ -70,7 +70,7 @@ func CheckForSoftFail(
)
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
@ -88,7 +88,7 @@ func CheckForSoftFail(
func CheckAuthEvents(
ctx context.Context,
db storage.RoomDatabase,
roomInfo *types.RoomInfo,
roomVersion gomatrixserverlib.RoomVersion,
event *types.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
@ -103,7 +103,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries)
if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err)
}
@ -196,7 +196,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) g
func loadAuthEvents(
ctx context.Context,
db state.StateResolutionStorage,
roomInfo *types.RoomInfo,
roomVersion gomatrixserverlib.RoomVersion,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
@ -220,11 +220,7 @@ func loadAuthEvents(
}
}
if roomInfo == nil {
err = types.ErrorInvalidRoomInfo
return
}
if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil {
return
}
roomID := ""

View file

@ -142,9 +142,28 @@ func (r *Inviter) PerformInvite(
if err != nil {
return nil, err
}
info, inviteState, err := r.generateInviteStrippedState(ctx, *validRoomID, req.Event, req.InviteRoomState)
if err != nil {
return nil, err
inviteState := req.InviteRoomState
if len(inviteState) == 0 {
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
stateWanted := []gomatrixserverlib.StateKeyTuple{}
for _, t := range []string{
spec.MRoomName, spec.MRoomCanonicalAlias,
spec.MRoomJoinRules, spec.MRoomAvatar,
spec.MRoomEncryption, spec.MRoomCreate,
} {
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
EventType: t,
StateKey: "",
})
}
if is, err := r.GenerateInviteStrippedState(ctx, *validRoomID, stateWanted, req.Event); err == nil {
inviteState = is
} else {
util.GetLogger(ctx).WithError(err).Error("failed querying known room")
return nil, spec.InternalServerError{}
}
}
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
@ -154,10 +173,9 @@ func (r *Inviter) PerformInvite(
"event_id": event.EventID(),
})
logger.WithFields(log.Fields{
"room_version": req.RoomVersion,
"room_info_exists": info != nil,
"target_local": isTargetLocal,
"origin_local": isOriginLocal,
"room_version": req.RoomVersion,
"target_local": isTargetLocal,
"origin_local": true,
}).Debug("processing invite event")
if len(inviteState) == 0 {
@ -170,14 +188,17 @@ func (r *Inviter) PerformInvite(
}
}
var isAlreadyJoined bool
if info != nil {
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
if err != nil {
return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
}
membershipReq := api.QueryMembershipForUserRequest{
RoomID: validRoomID.String(),
UserID: *event.StateKey(),
}
if isAlreadyJoined {
res := api.QueryMembershipForUserResponse{}
err = r.RSAPI.QueryMembershipForUser(ctx, &membershipReq, &res)
if err != nil {
return nil, fmt.Errorf("r.RSAPI.QueryMembershipForUser: %w", err)
}
if res.Membership == spec.Join {
// If the user is joined to the room then that takes precedence over this
// invite event. It makes little sense to move a user that is already
// joined to the room into the invite state.
@ -213,7 +234,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
_, err = helpers.CheckAuthEvents(ctx, r.DB, req.RoomVersion, event, event.AuthEventIDs())
if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"ProcessInviteEvent.checkAuthEvents failed for event",
@ -266,68 +287,3 @@ func (r *Inviter) PerformInvite(
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
return outputUpdates, nil
}
func (r *Inviter) generateInviteStrippedState(
ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []gomatrixserverlib.InviteStrippedState,
) (*types.RoomInfo, []gomatrixserverlib.InviteStrippedState, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err)
}
strippedState := inviteState
if len(strippedState) == 0 && info != nil {
var is []gomatrixserverlib.InviteStrippedState
if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil {
strippedState = is
}
}
return info, strippedState, nil
}
func buildInviteStrippedState(
ctx context.Context,
db storage.Database,
info *types.RoomInfo,
inviteEvent *types.HeaderedEvent,
) ([]gomatrixserverlib.InviteStrippedState, error) {
stateWanted := []gomatrixserverlib.StateKeyTuple{}
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
for _, t := range []string{
spec.MRoomName, spec.MRoomCanonicalAlias,
spec.MRoomJoinRules, spec.MRoomAvatar,
spec.MRoomEncryption, spec.MRoomCreate,
} {
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
EventType: t,
StateKey: "",
})
}
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, info.StateSnapshotNID(), stateWanted,
)
if err != nil {
return nil, err
}
stateNIDs := []types.EventNID{}
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
if info == nil {
return nil, types.ErrorInvalidRoomInfo
}
stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
if err != nil {
return nil, err
}
inviteState := []gomatrixserverlib.InviteStrippedState{
gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU),
}
stateEvents = append(stateEvents, types.Event{PDU: inviteEvent.PDU})
for _, event := range stateEvents {
inviteState = append(inviteState, gomatrixserverlib.NewInviteStrippedState(event.PDU))
}
return inviteState, nil
}