From c3643feaf5eb058c626861f356f0a4d0d30bc2c8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 6 May 2020 17:53:26 +0100 Subject: [PATCH] Fix invite loopback --- roomserver/internal/input.go | 11 ++++- roomserver/internal/input_events.go | 63 ++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go index 16f6d6bba..a3a88e409 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input.go @@ -18,6 +18,7 @@ package internal import ( "context" "encoding/json" + "fmt" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" @@ -59,9 +60,17 @@ func (r *RoomserverInternalAPI) InputRoomEvents( r.mutex.Lock() defer r.mutex.Unlock() for i := range request.InputInviteEvents { - if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { + var loopback *api.InputRoomEvent + if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { return err } + // The processInviteEvent function can optionally return a + // loopback room event containing the invite, for local invites. + // If it does, we should process it with the room events below. + if loopback != nil { + fmt.Println("LOOPING BACK", string(loopback.Event.JSON())) + request.InputRoomEvents = append(request.InputRoomEvents, *loopback) + } } for i := range request.InputRoomEvents { if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index 6da63716c..b17076efe 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -18,6 +18,7 @@ package internal import ( "context" + "errors" "fmt" "github.com/matrix-org/dendrite/common" @@ -132,11 +133,11 @@ func calculateAndSetState( func processInviteEvent( ctx context.Context, db storage.Database, - ow OutputRoomEventWriter, + ow *RoomserverInternalAPI, input api.InputInviteEvent, -) (err error) { +) (*api.InputRoomEvent, error) { if input.Event.StateKey() == nil { - return fmt.Errorf("invite must be a state event") + return nil, fmt.Errorf("invite must be a state event") } roomID := input.Event.RoomID() @@ -151,7 +152,7 @@ func processInviteEvent( updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) if err != nil { - return err + return nil, err } succeeded := false defer func() { @@ -189,17 +190,27 @@ func processInviteEvent( // For now we will implement option 2. Since in the abesence of a retry // mechanism it will be equivalent to option 1, and we don't have a // signalling mechanism to implement option 3. - return nil + return nil, nil + } + + // Normally, with a federated invite, the federation sender would do + // the /v2/invite request (in which the remote server signs the invite) + // and then the signed event gets sent back to the roomserver as an input + // event. When the invite is local, we don't interact with the federation + // sender therefore we need to generate the loopback invite event for + // the room ourselves. + loopback, err := localInviteLoopback(ow, input) + if err != nil { + return nil, err } event := input.Event.Unwrap() - if len(input.InviteRoomState) > 0 { // If we were supplied with some invite room state already (which is // most likely to be if the event came in over federation) then use // that. if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil { - return err + return nil, err } } else { // There's no invite room state, so let's have a go at building it @@ -208,22 +219,52 @@ func processInviteEvent( // the invite room state, if we don't then we just fail quietly. if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil { if err = event.SetUnsignedField("invite_room_state", irs); err != nil { - return err + return nil, err } } } outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion) if err != nil { - return err + return nil, err } if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { - return err + return nil, err } succeeded = true - return nil + return loopback, nil +} + +func localInviteLoopback( + ow *RoomserverInternalAPI, + input api.InputInviteEvent, +) (ire *api.InputRoomEvent, err error) { + if input.Event.StateKey() == nil { + return nil, errors.New("no state key on invite event") + } + ourServerName := string(ow.Cfg.Matrix.ServerName) + _, theirServerName, err := gomatrixserverlib.SplitID('@', *input.Event.StateKey()) + if err != nil { + return nil, err + } + // Check if the invite originated locally and is destined locally. + if input.Event.Origin() == ow.Cfg.Matrix.ServerName && string(theirServerName) == ourServerName { + rsEvent := input.Event.Sign( + ourServerName, + ow.Cfg.Matrix.KeyID, + ow.Cfg.Matrix.PrivateKey, + ).Headered(input.RoomVersion) + ire = &api.InputRoomEvent{ + Kind: api.KindNew, + Event: rsEvent, + AuthEventIDs: rsEvent.AuthEventIDs(), + SendAsServer: ourServerName, + TransactionID: nil, + } + } + return ire, nil } func buildInviteStrippedState(