From cf1e12c81409a6a10c97ae6b615e2c5e7401735f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 22 Jul 2021 10:46:40 +0100 Subject: [PATCH] Knock in membership updater --- roomserver/internal/input/input_membership.go | 14 +++++++++++ .../storage/shared/membership_updater.go | 23 +++++++++++++++++++ roomserver/storage/tables/interface.go | 1 + 3 files changed, 38 insertions(+) diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 44435bfd9..2511097d0 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -136,6 +136,8 @@ func (r *Inputer) updateMembership( return updateToJoinMembership(mu, add, updates) case gomatrixserverlib.Leave, gomatrixserverlib.Ban: return updateToLeaveMembership(mu, add, newMembership, updates) + case gomatrixserverlib.Knock: + return updateToKnockMembership(mu, add, updates) default: panic(fmt.Errorf( "input: membership %q is not one of the allowed values", newMembership, @@ -220,6 +222,18 @@ func updateToLeaveMembership( return updates, nil } +func updateToKnockMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + if mu.IsLeave() { + _, err := mu.SetToKnock(add) + if err != nil { + return nil, err + } + } + return updates, nil +} + // membershipChanges pairs up the membership state changes. func membershipChanges(removed, added []types.StateEntry) []stateChange { changes := pairUpChanges(removed, added) diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 57f3a520a..29232c9d6 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -86,6 +86,11 @@ func (u *MembershipUpdater) IsLeave() bool { return u.membership == tables.MembershipStateLeaveOrBan } +// IsKnock implements types.MembershipUpdater +func (u *MembershipUpdater) IsKnock() bool { + return u.membership == tables.MembershipStateKnock +} + // SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { var inserted bool @@ -180,3 +185,21 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s }) return inviteEventIDs, err } + +// SetToKnock implements types.MembershipUpdater +func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) { + var inserted bool + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) + } + if u.membership != tables.MembershipStateKnock { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, 0, false); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + return nil + }) + return inserted, err +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index f762cb712..8720d4007 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -120,6 +120,7 @@ const ( MembershipStateLeaveOrBan MembershipState = 1 MembershipStateInvite MembershipState = 2 MembershipStateJoin MembershipState = 3 + MembershipStateKnock MembershipState = 4 ) type Membership interface {