Fix invite loopback

This commit is contained in:
Neil Alexander 2020-05-06 17:53:26 +01:00
parent c9dc2bcd9b
commit c3643feaf5
2 changed files with 62 additions and 12 deletions

View file

@ -18,6 +18,7 @@ package internal
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -59,9 +60,17 @@ func (r *RoomserverInternalAPI) InputRoomEvents(
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
for i := range request.InputInviteEvents { for i := range request.InputInviteEvents {
if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { var loopback *api.InputRoomEvent
if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
return err return err
} }
// The processInviteEvent function can optionally return a
// loopback room event containing the invite, for local invites.
// If it does, we should process it with the room events below.
if loopback != nil {
fmt.Println("LOOPING BACK", string(loopback.Event.JSON()))
request.InputRoomEvents = append(request.InputRoomEvents, *loopback)
}
} }
for i := range request.InputRoomEvents { for i := range request.InputRoomEvents {
if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {

View file

@ -18,6 +18,7 @@ package internal
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -132,11 +133,11 @@ func calculateAndSetState(
func processInviteEvent( func processInviteEvent(
ctx context.Context, ctx context.Context,
db storage.Database, db storage.Database,
ow OutputRoomEventWriter, ow *RoomserverInternalAPI,
input api.InputInviteEvent, input api.InputInviteEvent,
) (err error) { ) (*api.InputRoomEvent, error) {
if input.Event.StateKey() == nil { if input.Event.StateKey() == nil {
return fmt.Errorf("invite must be a state event") return nil, fmt.Errorf("invite must be a state event")
} }
roomID := input.Event.RoomID() roomID := input.Event.RoomID()
@ -151,7 +152,7 @@ func processInviteEvent(
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion)
if err != nil { if err != nil {
return err return nil, err
} }
succeeded := false succeeded := false
defer func() { defer func() {
@ -189,17 +190,27 @@ func processInviteEvent(
// For now we will implement option 2. Since in the abesence of a retry // For now we will implement option 2. Since in the abesence of a retry
// mechanism it will be equivalent to option 1, and we don't have a // mechanism it will be equivalent to option 1, and we don't have a
// signalling mechanism to implement option 3. // signalling mechanism to implement option 3.
return nil return nil, nil
}
// Normally, with a federated invite, the federation sender would do
// the /v2/invite request (in which the remote server signs the invite)
// and then the signed event gets sent back to the roomserver as an input
// event. When the invite is local, we don't interact with the federation
// sender therefore we need to generate the loopback invite event for
// the room ourselves.
loopback, err := localInviteLoopback(ow, input)
if err != nil {
return nil, err
} }
event := input.Event.Unwrap() event := input.Event.Unwrap()
if len(input.InviteRoomState) > 0 { if len(input.InviteRoomState) > 0 {
// If we were supplied with some invite room state already (which is // If we were supplied with some invite room state already (which is
// most likely to be if the event came in over federation) then use // most likely to be if the event came in over federation) then use
// that. // that.
if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil { if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil {
return err return nil, err
} }
} else { } else {
// There's no invite room state, so let's have a go at building it // There's no invite room state, so let's have a go at building it
@ -208,22 +219,52 @@ func processInviteEvent(
// the invite room state, if we don't then we just fail quietly. // the invite room state, if we don't then we just fail quietly.
if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil { if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil {
if err = event.SetUnsignedField("invite_room_state", irs); err != nil { if err = event.SetUnsignedField("invite_room_state", irs); err != nil {
return err return nil, err
} }
} }
} }
outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion) outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion)
if err != nil { if err != nil {
return err return nil, err
} }
if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil {
return err return nil, err
} }
succeeded = true succeeded = true
return nil return loopback, nil
}
func localInviteLoopback(
ow *RoomserverInternalAPI,
input api.InputInviteEvent,
) (ire *api.InputRoomEvent, err error) {
if input.Event.StateKey() == nil {
return nil, errors.New("no state key on invite event")
}
ourServerName := string(ow.Cfg.Matrix.ServerName)
_, theirServerName, err := gomatrixserverlib.SplitID('@', *input.Event.StateKey())
if err != nil {
return nil, err
}
// Check if the invite originated locally and is destined locally.
if input.Event.Origin() == ow.Cfg.Matrix.ServerName && string(theirServerName) == ourServerName {
rsEvent := input.Event.Sign(
ourServerName,
ow.Cfg.Matrix.KeyID,
ow.Cfg.Matrix.PrivateKey,
).Headered(input.RoomVersion)
ire = &api.InputRoomEvent{
Kind: api.KindNew,
Event: rsEvent,
AuthEventIDs: rsEvent.AuthEventIDs(),
SendAsServer: ourServerName,
TransactionID: nil,
}
}
return ire, nil
} }
func buildInviteStrippedState( func buildInviteStrippedState(