mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-15 01:53:09 -06:00
Refactor PerformInvite to remove duplication
This commit is contained in:
parent
b426961d84
commit
e678912a80
|
|
@ -70,7 +70,7 @@ func CheckForSoftFail(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -88,7 +88,7 @@ func CheckForSoftFail(
|
||||||
func CheckAuthEvents(
|
func CheckAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.RoomDatabase,
|
db storage.RoomDatabase,
|
||||||
roomInfo *types.RoomInfo,
|
roomVersion gomatrixserverlib.RoomVersion,
|
||||||
event *types.HeaderedEvent,
|
event *types.HeaderedEvent,
|
||||||
authEventIDs []string,
|
authEventIDs []string,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
|
|
@ -103,7 +103,7 @@ func CheckAuthEvents(
|
||||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU})
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.PDU})
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loadAuthEvents: %w", err)
|
return nil, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -196,7 +196,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) g
|
||||||
func loadAuthEvents(
|
func loadAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db state.StateResolutionStorage,
|
db state.StateResolutionStorage,
|
||||||
roomInfo *types.RoomInfo,
|
roomVersion gomatrixserverlib.RoomVersion,
|
||||||
needed gomatrixserverlib.StateNeeded,
|
needed gomatrixserverlib.StateNeeded,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (result authEvents, err error) {
|
) (result authEvents, err error) {
|
||||||
|
|
@ -220,11 +220,7 @@ func loadAuthEvents(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if roomInfo == nil {
|
if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil {
|
||||||
err = types.ErrorInvalidRoomInfo
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
roomID := ""
|
roomID := ""
|
||||||
|
|
|
||||||
|
|
@ -142,9 +142,28 @@ func (r *Inviter) PerformInvite(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
info, inviteState, err := r.generateInviteStrippedState(ctx, *validRoomID, req.Event, req.InviteRoomState)
|
|
||||||
if err != nil {
|
inviteState := req.InviteRoomState
|
||||||
return nil, err
|
if len(inviteState) == 0 {
|
||||||
|
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
|
||||||
|
stateWanted := []gomatrixserverlib.StateKeyTuple{}
|
||||||
|
for _, t := range []string{
|
||||||
|
spec.MRoomName, spec.MRoomCanonicalAlias,
|
||||||
|
spec.MRoomJoinRules, spec.MRoomAvatar,
|
||||||
|
spec.MRoomEncryption, spec.MRoomCreate,
|
||||||
|
} {
|
||||||
|
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: t,
|
||||||
|
StateKey: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if is, err := r.GenerateInviteStrippedState(ctx, *validRoomID, stateWanted, req.Event); err == nil {
|
||||||
|
inviteState = is
|
||||||
|
} else {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("failed querying known room")
|
||||||
|
return nil, spec.InternalServerError{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
|
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
|
||||||
|
|
@ -154,10 +173,9 @@ func (r *Inviter) PerformInvite(
|
||||||
"event_id": event.EventID(),
|
"event_id": event.EventID(),
|
||||||
})
|
})
|
||||||
logger.WithFields(log.Fields{
|
logger.WithFields(log.Fields{
|
||||||
"room_version": req.RoomVersion,
|
"room_version": req.RoomVersion,
|
||||||
"room_info_exists": info != nil,
|
"target_local": isTargetLocal,
|
||||||
"target_local": isTargetLocal,
|
"origin_local": true,
|
||||||
"origin_local": isOriginLocal,
|
|
||||||
}).Debug("processing invite event")
|
}).Debug("processing invite event")
|
||||||
|
|
||||||
if len(inviteState) == 0 {
|
if len(inviteState) == 0 {
|
||||||
|
|
@ -170,14 +188,17 @@ func (r *Inviter) PerformInvite(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var isAlreadyJoined bool
|
membershipReq := api.QueryMembershipForUserRequest{
|
||||||
if info != nil {
|
RoomID: validRoomID.String(),
|
||||||
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
|
UserID: *event.StateKey(),
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if isAlreadyJoined {
|
res := api.QueryMembershipForUserResponse{}
|
||||||
|
err = r.RSAPI.QueryMembershipForUser(ctx, &membershipReq, &res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("r.RSAPI.QueryMembershipForUser: %w", err)
|
||||||
|
}
|
||||||
|
if res.Membership == spec.Join {
|
||||||
// If the user is joined to the room then that takes precedence over this
|
// If the user is joined to the room then that takes precedence over this
|
||||||
// invite event. It makes little sense to move a user that is already
|
// invite event. It makes little sense to move a user that is already
|
||||||
// joined to the room into the invite state.
|
// joined to the room into the invite state.
|
||||||
|
|
@ -213,7 +234,7 @@ func (r *Inviter) PerformInvite(
|
||||||
// try and see if the user is allowed to make this invite. We can't do
|
// try and see if the user is allowed to make this invite. We can't do
|
||||||
// this for invites coming in over federation - we have to take those on
|
// this for invites coming in over federation - we have to take those on
|
||||||
// trust.
|
// trust.
|
||||||
_, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
|
_, err = helpers.CheckAuthEvents(ctx, r.DB, req.RoomVersion, event, event.AuthEventIDs())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||||
"ProcessInviteEvent.checkAuthEvents failed for event",
|
"ProcessInviteEvent.checkAuthEvents failed for event",
|
||||||
|
|
@ -266,68 +287,3 @@ func (r *Inviter) PerformInvite(
|
||||||
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
|
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
|
||||||
return outputUpdates, nil
|
return outputUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Inviter) generateInviteStrippedState(
|
|
||||||
ctx context.Context, roomID spec.RoomID, inviteEvent *types.HeaderedEvent, inviteState []gomatrixserverlib.InviteStrippedState,
|
|
||||||
) (*types.RoomInfo, []gomatrixserverlib.InviteStrippedState, error) {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, roomID.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
|
||||||
}
|
|
||||||
strippedState := inviteState
|
|
||||||
if len(strippedState) == 0 && info != nil {
|
|
||||||
var is []gomatrixserverlib.InviteStrippedState
|
|
||||||
if is, err = buildInviteStrippedState(ctx, r.DB, info, inviteEvent); err == nil {
|
|
||||||
strippedState = is
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return info, strippedState, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildInviteStrippedState(
|
|
||||||
ctx context.Context,
|
|
||||||
db storage.Database,
|
|
||||||
info *types.RoomInfo,
|
|
||||||
inviteEvent *types.HeaderedEvent,
|
|
||||||
) ([]gomatrixserverlib.InviteStrippedState, error) {
|
|
||||||
stateWanted := []gomatrixserverlib.StateKeyTuple{}
|
|
||||||
// "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included."
|
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member
|
|
||||||
for _, t := range []string{
|
|
||||||
spec.MRoomName, spec.MRoomCanonicalAlias,
|
|
||||||
spec.MRoomJoinRules, spec.MRoomAvatar,
|
|
||||||
spec.MRoomEncryption, spec.MRoomCreate,
|
|
||||||
} {
|
|
||||||
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
|
|
||||||
EventType: t,
|
|
||||||
StateKey: "",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
roomState := state.NewStateResolution(db, info)
|
|
||||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
|
||||||
ctx, info.StateSnapshotNID(), stateWanted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateNIDs := []types.EventNID{}
|
|
||||||
for _, stateNID := range stateEntries {
|
|
||||||
stateNIDs = append(stateNIDs, stateNID.EventNID)
|
|
||||||
}
|
|
||||||
if info == nil {
|
|
||||||
return nil, types.ErrorInvalidRoomInfo
|
|
||||||
}
|
|
||||||
stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
inviteState := []gomatrixserverlib.InviteStrippedState{
|
|
||||||
gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU),
|
|
||||||
}
|
|
||||||
stateEvents = append(stateEvents, types.Event{PDU: inviteEvent.PDU})
|
|
||||||
for _, event := range stateEvents {
|
|
||||||
inviteState = append(inviteState, gomatrixserverlib.NewInviteStrippedState(event.PDU))
|
|
||||||
}
|
|
||||||
return inviteState, nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue