From 285d0638275a76cdabb4ff268bcec2743a9141ac Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 27 Jun 2022 11:42:08 +0100 Subject: [PATCH] Membership updater refactoring --- federationapi/api/api.go | 1 + federationapi/internal/perform.go | 1 + roomserver/internal/helpers/helpers.go | 2 +- roomserver/internal/input/input_membership.go | 39 +---- roomserver/internal/perform/perform_leave.go | 2 +- .../storage/shared/membership_updater.go | 155 ++++++------------ 6 files changed, 55 insertions(+), 145 deletions(-) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 53d4701f3..89e726f98 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -181,6 +181,7 @@ type PerformLeaveRequest struct { } type PerformLeaveResponse struct { + Event *gomatrixserverlib.HeaderedEvent `json:"event"` } type PerformInviteRequest struct { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 9100c8f18..98fac039b 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -561,6 +561,7 @@ func (r *FederationInternalAPI) PerformLeave( } r.statistics.ForServer(serverName).Success() + response.Event = event.Headered(respMakeLeave.RoomVersion) return nil } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index e67bbfcaa..474f3892c 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -28,7 +28,7 @@ func UpdateToInviteMembership( // reprocessing this event, or because the we received this invite from a // remote server via the federation invite API. In those cases we don't need // to send the event. - needsSending, err := mu.SetToInvite(add) + needsSending, _, err := mu.Update(add) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 3ce8791a3..11993fd87 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -90,26 +90,13 @@ func (r *Inputer) updateMembership( ) ([]api.OutputEvent, error) { var err error // Default the membership to Leave if no event was added or removed. - oldMembership := gomatrixserverlib.Leave newMembership := gomatrixserverlib.Leave - - if remove != nil { - oldMembership, err = remove.Membership() - if err != nil { - return nil, err - } - } if add != nil { newMembership, err = add.Membership() if err != nil { return nil, err } } - if oldMembership == newMembership && newMembership != gomatrixserverlib.Join { - // If the membership is the same then nothing changed and we can return - // immediately, unless it's a Join update (e.g. profile update). - return updates, nil - } // In an ideal world, we shouldn't ever have "add" be nil and "remove" be // set, as this implies that we're deleting a state event without replacing @@ -161,21 +148,11 @@ func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { func updateToJoinMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { - // If the user is already marked as being joined, we call SetToJoin to update - // the event ID then we can return immediately. Retired is ignored as there - // is no invite event to retire. - if mu.IsJoin() { - _, err := mu.SetToJoin(add.Sender(), add.EventID(), true) - if err != nil { - return nil, err - } - return updates, nil - } // When we mark a user as being joined we will invalidate any invites that // are active for that user. We notify the consumers that the invites have // been retired using a special event, even though they could infer this // by studying the state changes in the room event stream. - retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) + _, retired, err := mu.Update(add) if err != nil { return nil, err } @@ -198,16 +175,11 @@ func updateToLeaveMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, newMembership string, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { - // If the user is already neither joined, nor invited to the room then we - // can return immediately. - if mu.IsLeave() { - return updates, nil - } // When we mark a user as having left we will invalidate any invites that // are active for that user. We notify the consumers that the invites have // been retired using a special event, even though they could infer this // by studying the state changes in the room event stream. - retired, err := mu.SetToLeave(add.Sender(), add.EventID()) + _, retired, err := mu.Update(add) if err != nil { return nil, err } @@ -229,11 +201,8 @@ func updateToLeaveMembership( func updateToKnockMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { - if mu.IsLeave() { - _, err := mu.SetToKnock(add) - if err != nil { - return nil, err - } + if _, _, err := mu.Update(add); err != nil { + return nil, err } return updates, nil } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index c5b62ac00..0377cbb84 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -228,7 +228,7 @@ func (r *Leaver) performFederatedRejectInvite( util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") } if updater != nil { - if _, err = updater.SetToLeave(req.UserID, eventID); err != nil { + if _, _, err = updater.Update(leaveRes.Event.Unwrap()); err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to set membership to leave, still retiring invite event") if err = updater.Rollback(); err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to rollback membership leave, still retiring invite event") diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index ebfcef569..277b2d89a 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -15,7 +15,7 @@ type MembershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership tables.MembershipState + oldMembership tables.MembershipState } func NewMembershipUpdater( @@ -30,7 +30,6 @@ func NewMembershipUpdater( if err != nil { return err } - targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID) if err != nil { return err @@ -73,139 +72,79 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *MembershipUpdater) IsInvite() bool { - return u.membership == tables.MembershipStateInvite + return u.oldMembership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *MembershipUpdater) IsJoin() bool { - return u.membership == tables.MembershipStateJoin + return u.oldMembership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *MembershipUpdater) IsLeave() bool { - return u.membership == tables.MembershipStateLeaveOrBan + return u.oldMembership == tables.MembershipStateLeaveOrBan } // IsKnock implements types.MembershipUpdater func (u *MembershipUpdater) IsKnock() bool { - return u.membership == tables.MembershipStateKnock + return u.oldMembership == tables.MembershipStateKnock } -// SetToInvite implements types.MembershipUpdater -func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) { +func (u *MembershipUpdater) Update(event *gomatrixserverlib.Event) (bool, []string, error) { var inserted bool - err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + var retired []string + return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + membership, err := event.Membership() + if err != nil { + return fmt.Errorf("event.Membership: %w", err) + } senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender()) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - inserted, err = u.d.InvitesTable.InsertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) + var newMembership tables.MembershipState + switch membership { + case gomatrixserverlib.Join: + newMembership = tables.MembershipStateJoin + case gomatrixserverlib.Leave, gomatrixserverlib.Ban: + newMembership = tables.MembershipStateLeaveOrBan + case gomatrixserverlib.Invite: + newMembership = tables.MembershipStateInvite + case gomatrixserverlib.Knock: + newMembership = tables.MembershipStateKnock + default: + return fmt.Errorf("unrecognised membership %q", membership) } - if u.membership != tables.MembershipStateInvite { - if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { - return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + eventID := event.EventID() + eventNIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) + if err != nil { + return fmt.Errorf("u.d.eventNIDs: %w", err) + } + inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, newMembership, eventNIDs[eventID], false) + if err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + if !inserted { + return nil + } + switch { + case u.oldMembership == tables.MembershipStateLeaveOrBan && newMembership == tables.MembershipStateInvite: + // add invite entry + inserted, err = u.d.InvitesTable.InsertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } - } - return nil - }) - return inserted, err -} - -// SetToJoin implements types.MembershipUpdater -func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { - var inviteEventIDs []string - - err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID) - if err != nil { - return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + case u.oldMembership == tables.MembershipStateInvite && newMembership == tables.MembershipStateLeaveOrBan: + // retire event + retired, err = u.d.InvitesTable.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) } } - - // Look up the NID of the new join event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) - if err != nil { - return fmt.Errorf("u.d.EventNIDs: %w", err) - } - - if u.membership != tables.MembershipStateJoin || isUpdate { - if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { - return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) - } - } - return nil }) - - return inviteEventIDs, err -} - -// SetToLeave implements types.MembershipUpdater -func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - var inviteEventIDs []string - - err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID) - if err != nil { - return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) - } - - // Look up the NID of the new leave event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) - if err != nil { - return fmt.Errorf("u.d.EventNIDs: %w", err) - } - - if u.membership != tables.MembershipStateLeaveOrBan { - if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { - return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) - } - } - - return nil - }) - return inviteEventIDs, err -} - -// SetToKnock implements types.MembershipUpdater -func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) { - var inserted bool - err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender()) - if err != nil { - return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - if u.membership != tables.MembershipStateKnock { - // Look up the NID of the new knock event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false) - if err != nil { - return fmt.Errorf("u.d.EventNIDs: %w", err) - } - - if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { - return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) - } - } - return nil - }) - return inserted, err }