mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 16:13:10 -06:00
Refactor PerformInvite to remove duplication
This commit is contained in:
parent
b426961d84
commit
e678912a80
|
|
@ -70,7 +70,7 @@ func CheckForSoftFail(
|
|||
)
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
||||
}
|
||||
|
|
@ -88,7 +88,7 @@ func CheckForSoftFail(
|
|||
func CheckAuthEvents(
|
||||
ctx context.Context,
|
||||
db storage.RoomDatabase,
|
||||
roomInfo *types.RoomInfo,
|
||||
roomVersion gomatrixserverlib.RoomVersion,
|
||||
event *types.HeaderedEvent,
|
||||
authEventIDs []string,
|
||||
) ([]types.EventNID, error) {
|
||||
|
|
@ -103,7 +103,7 @@ func CheckAuthEvents(
|
|||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU})
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
||||
authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loadAuthEvents: %w", err)
|
||||
}
|
||||
|
|
@ -196,7 +196,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) g
|
|||
func loadAuthEvents(
|
||||
ctx context.Context,
|
||||
db state.StateResolutionStorage,
|
||||
roomInfo *types.RoomInfo,
|
||||
roomVersion gomatrixserverlib.RoomVersion,
|
||||
needed gomatrixserverlib.StateNeeded,
|
||||
state []types.StateEntry,
|
||||
) (result authEvents, err error) {
|
||||
|
|
@ -220,11 +220,7 @@ func loadAuthEvents(
|
|||
}
|
||||
}
|
||||
|
||||
if roomInfo == nil {
|
||||
err = types.ErrorInvalidRoomInfo
|
||||
return
|
||||
}
|
||||
if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
|
||||
if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil {
|
||||
return
|
||||
}
|
||||
roomID := ""
|
||||
|
|
|
|||
|
|
@ -142,9 +142,28 @@ func (r *Inviter) PerformInvite(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, inviteState, err := r.generateInviteStrippedState(ctx, *validRoomID, req.Event, req.InviteRoomState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
inviteState := req.InviteRoomState
|
||||
if len(inviteState) == 0 {
|
||||
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
|
||||
stateWanted := []gomatrixserverlib.StateKeyTuple{}
|
||||
for _, t := range []string{
|
||||
spec.MRoomName, spec.MRoomCanonicalAlias,
|
||||
spec.MRoomJoinRules, spec.MRoomAvatar,
|
||||
spec.MRoomEncryption, spec.MRoomCreate,
|
||||
} {
|
||||
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
|
||||
EventType: t,
|
||||
StateKey: "",
|
||||
})
|
||||
}
|
||||
if is, err := r.GenerateInviteStrippedState(ctx, *validRoomID, stateWanted, req.Event); err == nil {
|
||||
inviteState = is
|
||||
} else {
|
||||
util.GetLogger(ctx).WithError(err).Error("failed querying known room")
|
||||
return nil, spec.InternalServerError{}
|
||||
}
|
||||
}
|
||||
|
||||
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
|
||||
|
|
@ -154,10 +173,9 @@ func (r *Inviter) PerformInvite(
|
|||
"event_id": event.EventID(),
|
||||
})
|
||||
logger.WithFields(log.Fields{
|
||||
"room_version": req.RoomVersion,
|
||||
"room_info_exists": info != nil,
|
||||
"target_local": isTargetLocal,
|
||||
"origin_local": isOriginLocal,
|
||||
"room_version": req.RoomVersion,
|
||||
"target_local": isTargetLocal,
|
||||
"origin_local": true,
|
||||
}).Debug("processing invite event")
|
||||
|
||||
if len(inviteState) == 0 {
|
||||
|
|
@ -170,14 +188,17 @@ func (r *Inviter) PerformInvite(
|
|||
}
|
||||
}
|
||||
|
||||
var isAlreadyJoined bool
|
||||
if info != nil {
|
||||
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
|
||||
}
|
||||
membershipReq := api.QueryMembershipForUserRequest{
|
||||
RoomID: validRoomID.String(),
|
||||
UserID: *event.StateKey(),
|
||||
}
|
||||
if isAlreadyJoined {
|
||||
res := api.QueryMembershipForUserResponse{}
|
||||
err = r.RSAPI.QueryMembershipForUser(ctx, &membershipReq, &res)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.RSAPI.QueryMembershipForUser: %w", err)
|
||||
}
|
||||
if res.Membership == spec.Join {
|
||||
// If the user is joined to the room then that takes precedence over this
|
||||
// invite event. It makes little sense to move a user that is already
|
||||
// joined to the room into the invite state.
|
||||
|
|
@ -213,7 +234,7 @@ func (r *Inviter) PerformInvite(
|
|||
// try and see if the user is allowed to make this invite. We can't do
|
||||
// this for invites coming in over federation - we have to take those on
|
||||
// trust.
|
||||
_, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
|
||||
_, err = helpers.CheckAuthEvents(ctx, r.DB, req.RoomVersion, event, event.AuthEventIDs())
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||
"ProcessInviteEvent.checkAuthEvents failed for event",
|
||||
|
|
@ -266,68 +287,3 @@ func (r *Inviter) PerformInvite(
|
|||
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
|
||||
return outputUpdates, nil
|
||||
}
|
||||
|
||||
func (r *Inviter) generateInviteStrippedState(
|
||||
ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []gomatrixserverlib.InviteStrippedState,
|
||||
) (*types.RoomInfo, []gomatrixserverlib.InviteStrippedState, error) {
|
||||
info, err := r.DB.RoomInfo(ctx, roomID.String())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
||||
}
|
||||
strippedState := inviteState
|
||||
if len(strippedState) == 0 && info != nil {
|
||||
var is []gomatrixserverlib.InviteStrippedState
|
||||
if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil {
|
||||
strippedState = is
|
||||
}
|
||||
}
|
||||
|
||||
return info, strippedState, nil
|
||||
}
|
||||
|
||||
func buildInviteStrippedState(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
info *types.RoomInfo,
|
||||
inviteEvent *types.HeaderedEvent,
|
||||
) ([]gomatrixserverlib.InviteStrippedState, error) {
|
||||
stateWanted := []gomatrixserverlib.StateKeyTuple{}
|
||||
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
|
||||
for _, t := range []string{
|
||||
spec.MRoomName, spec.MRoomCanonicalAlias,
|
||||
spec.MRoomJoinRules, spec.MRoomAvatar,
|
||||
spec.MRoomEncryption, spec.MRoomCreate,
|
||||
} {
|
||||
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
|
||||
EventType: t,
|
||||
StateKey: "",
|
||||
})
|
||||
}
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
||||
ctx, info.StateSnapshotNID(), stateWanted,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateNIDs := []types.EventNID{}
|
||||
for _, stateNID := range stateEntries {
|
||||
stateNIDs = append(stateNIDs, stateNID.EventNID)
|
||||
}
|
||||
if info == nil {
|
||||
return nil, types.ErrorInvalidRoomInfo
|
||||
}
|
||||
stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inviteState := []gomatrixserverlib.InviteStrippedState{
|
||||
gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU),
|
||||
}
|
||||
stateEvents = append(stateEvents, types.Event{PDU: inviteEvent.PDU})
|
||||
for _, event := range stateEvents {
|
||||
inviteState = append(inviteState, gomatrixserverlib.NewInviteStrippedState(event.PDU))
|
||||
}
|
||||
return inviteState, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue