From 60653878d7fa2c31c80b90e706e8805c69aea56b Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 24 May 2023 09:55:55 -0600 Subject: [PATCH] Refactor invite to consolidate shared logic --- federationapi/routing/invite.go | 52 ++++----- go.mod | 2 +- go.sum | 4 +- roomserver/api/api.go | 2 +- roomserver/internal/api.go | 6 +- roomserver/internal/perform/perform_invite.go | 101 +++++++++--------- 6 files changed, 83 insertions(+), 84 deletions(-) diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 16bfd3541..4f5ae0949 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -76,19 +76,19 @@ func InviteV2( } input := gomatrixserverlib.HandleInviteInput{ - Context: httpReq.Context(), - RoomVersion: inviteReq.RoomVersion(), - RoomID: roomID, - EventID: eventID, - InvitedUser: *invitedUser, - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - Verifier: keys, - InviteQuerier: rsAPI, - MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - GenerateStrippedState: rsAPI.GenerateInviteStrippedState, - InviteEvent: inviteReq.Event(), - StrippedState: inviteReq.InviteRoomState(), + Context: httpReq.Context(), + RoomVersion: inviteReq.RoomVersion(), + RoomID: roomID, + EventID: eventID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + InviteQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: inviteReq.Event(), + StrippedState: inviteReq.InviteRoomState(), } event, jsonErr := handleInvite(input, rsAPI) if jsonErr != nil { @@ -161,19 +161,19 @@ func InviteV1( } input := gomatrixserverlib.HandleInviteInput{ - Context: httpReq.Context(), - RoomVersion: roomVer, - RoomID: roomID, - EventID: eventID, - InvitedUser: *invitedUser, - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - Verifier: keys, - InviteQuerier: rsAPI, - MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - GenerateStrippedState: rsAPI.GenerateInviteStrippedState, - InviteEvent: event, - StrippedState: strippedState, + Context: httpReq.Context(), + RoomVersion: roomVer, + RoomID: roomID, + EventID: eventID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + InviteQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: event, + StrippedState: strippedState, } event, jsonErr := handleInvite(input, rsAPI) if jsonErr != nil { diff --git a/go.mod b/go.mod index ab70f155b..ea45e47ef 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230524040519-2b64a2fae808 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230524154314-f588632e6d1a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 3b638d9bd..6d64f0296 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230524040519-2b64a2fae808 h1:71+KiT2O0AVk+Gb/SeFhs/gN9VHexGkEgYPR3soSDvo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230524040519-2b64a2fae808/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230524154314-f588632e6d1a h1:a/9ZAYb7AsjVY1gRRmILKvbd27/sWYDjQweksB33VmQ= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230524154314-f588632e6d1a/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index b4e3de564..544597f3f 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -236,7 +236,7 @@ type FederationRoomserverAPI interface { LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) - GenerateInviteStrippedState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU) ([]gomatrixserverlib.InviteStrippedState, error) + StateQuerier() gomatrixserverlib.StateQuerier } type KeyserverRoomserverAPI interface { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index e97dc8d1d..ff619d079 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -212,10 +212,8 @@ func (r *RoomserverInternalAPI) IsKnownRoom(ctx context.Context, roomID spec.Roo return r.Inviter.IsKnownRoom(ctx, roomID) } -func (r *RoomserverInternalAPI) GenerateInviteStrippedState( - ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU, -) ([]gomatrixserverlib.InviteStrippedState, error) { - return r.Inviter.GenerateInviteStrippedState(ctx, roomID, stateWanted, inviteEvent) +func (r *RoomserverInternalAPI) StateQuerier() gomatrixserverlib.StateQuerier { + return r.Inviter.StateQuerier() } func (r *RoomserverInternalAPI) HandleInvite( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index b78984950..fee5f0f00 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -32,6 +32,47 @@ import ( "github.com/matrix-org/util" ) +type QueryState struct { + storage.Database +} + +func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { + return helpers.GetAuthEvents(ctx, q.Database, event.Version(), event, event.AuthEventIDs()) +} + +func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple) ([]gomatrixserverlib.PDU, error) { + info, err := q.Database.RoomInfo(ctx, roomID.String()) + if err != nil { + return nil, fmt.Errorf("failed to load RoomInfo: %w", err) + } + if info != nil { + roomState := state.NewStateResolution(q.Database, info) + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, info.StateSnapshotNID(), stateWanted, + ) + if err != nil { + return nil, nil + } + stateNIDs := []types.EventNID{} + for _, stateNID := range stateEntries { + stateNIDs = append(stateNIDs, stateNID.EventNID) + } + stateEvents, err := q.Database.Events(ctx, info.RoomVersion, stateNIDs) + if err != nil { + // TODO: really? no err? + return nil, nil + } + + events := []gomatrixserverlib.PDU{} + for _, event := range stateEvents { + events = append(events, event.PDU) + } + return events, nil + } + + return nil, nil +} + type Inviter struct { DB storage.Database Cfg *config.RoomServer @@ -48,39 +89,8 @@ func (r *Inviter) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, er return (info != nil && !info.IsStub()), nil } -func (r *Inviter) GenerateInviteStrippedState( - ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple, inviteEvent gomatrixserverlib.PDU, -) ([]gomatrixserverlib.InviteStrippedState, error) { - info, err := r.DB.RoomInfo(ctx, roomID.String()) - if err != nil { - return nil, fmt.Errorf("failed to load RoomInfo: %w", err) - } - if info != nil { - roomState := state.NewStateResolution(r.DB, info) - stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( - ctx, info.StateSnapshotNID(), stateWanted, - ) - if err != nil { - return nil, nil - } - stateNIDs := []types.EventNID{} - for _, stateNID := range stateEntries { - stateNIDs = append(stateNIDs, stateNID.EventNID) - } - stateEvents, err := r.DB.Events(ctx, info.RoomVersion, stateNIDs) - if err != nil { - return nil, nil - } - inviteState := []gomatrixserverlib.InviteStrippedState{ - gomatrixserverlib.NewInviteStrippedState(inviteEvent), - } - stateEvents = append(stateEvents, types.Event{PDU: inviteEvent}) - for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteStrippedState(event.PDU)) - } - return inviteState, nil - } - return nil, nil +func (r *Inviter) StateQuerier() gomatrixserverlib.StateQuerier { + return &QueryState{Database: r.DB} } func (r *Inviter) ProcessInviteMembership( @@ -109,14 +119,6 @@ func (r *Inviter) ProcessInviteMembership( return outputUpdates, nil } -type QueryState struct { - storage.Database -} - -func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { - return helpers.GetAuthEvents(ctx, q.Database, event.Version(), event, event.AuthEventIDs()) -} - // nolint:gocyclo func (r *Inviter) PerformInvite( ctx context.Context, @@ -147,15 +149,14 @@ func (r *Inviter) PerformInvite( } input := gomatrixserverlib.PerformInviteInput{ - Context: ctx, - RoomID: *validRoomID, - Event: event.PDU, - InvitedUser: *invitedUser, - IsTargetLocal: isTargetLocal, - StrippedState: req.InviteRoomState, - MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, - StateQuerier: &QueryState{r.DB}, - GenerateStrippedState: r.GenerateInviteStrippedState, + Context: ctx, + RoomID: *validRoomID, + Event: event.PDU, + InvitedUser: *invitedUser, + IsTargetLocal: isTargetLocal, + StrippedState: req.InviteRoomState, + MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, + StateQuerier: &QueryState{r.DB}, } inviteEvent, err := gomatrixserverlib.PerformInvite(input, r.FSAPI) if err != nil {