diff --git a/clientapi/producers/roomserver.go b/clientapi/producers/roomserver.go index f537557f5..513674607 100644 --- a/clientapi/producers/roomserver.go +++ b/clientapi/producers/roomserver.go @@ -99,3 +99,16 @@ func (c *RoomserverProducer) SendInputRoomEvents( eventID = response.EventID return } + +// SendInputNewInviteEvents writes the given input new events to the roomserver input API. +// The roomserver will automatically populate the invite room state for us before sending +// the invite onward. +func (c *RoomserverProducer) SendInputNewInviteEvents( + ctx context.Context, ires []api.InputRoomEvent, +) (eventID string, err error) { + request := api.InputRoomEventsRequest{InputRoomEvents: ires} + var response api.InputRoomEventsResponse + err = c.InputAPI.InputNewInviteEvents(ctx, &request, &response) + eventID = response.EventID + return +} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 9f386b718..fd29901ac 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -104,14 +104,31 @@ func SendMembership( return jsonerror.InternalServerError() } - if _, err := producer.SendEvents( - req.Context(), - []gomatrixserverlib.HeaderedEvent{(*event).Headered(verRes.RoomVersion)}, - cfg.Matrix.ServerName, - nil, - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") - return jsonerror.InternalServerError() + if membership == gomatrixserverlib.Invite { + if _, err := producer.SendInputNewInviteEvents( + req.Context(), + []api.InputRoomEvent{ + api.InputRoomEvent{ + Kind: api.KindNew, + Event: event.Headered(verRes.RoomVersion), + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: string(cfg.Matrix.ServerName), + }, + }, + ); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") + return jsonerror.InternalServerError() + } + } else { + if _, err := producer.SendEvents( + req.Context(), + []gomatrixserverlib.HeaderedEvent{(*event).Headered(verRes.RoomVersion)}, + cfg.Matrix.ServerName, + nil, + ); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("producer.SendEvents failed") + return jsonerror.InternalServerError() + } } var returnData interface{} = struct{}{} diff --git a/roomserver/api/input.go b/roomserver/api/input.go index fbedff2ed..41c741f80 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -99,10 +99,17 @@ type RoomserverInputAPI interface { request *InputRoomEventsRequest, response *InputRoomEventsResponse, ) error + + InputNewInviteEvents( + ctx context.Context, + request *InputRoomEventsRequest, + response *InputRoomEventsResponse, + ) error } // RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API. const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents" +const RoomserverInputNewInviteEventsPath = "/api/roomserver/inputRoomEvents" // NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API. // If httpClient is nil an error is returned @@ -130,3 +137,16 @@ func (h *httpRoomserverInputAPI) InputRoomEvents( apiURL := h.roomserverURL + RoomserverInputRoomEventsPath return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// InputRoomEvents implements RoomserverInputAPI +func (h *httpRoomserverInputAPI) InputNewInviteEvents( + ctx context.Context, + request *InputRoomEventsRequest, + response *InputRoomEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputNewInviteEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverInputNewInviteEventsPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/input/events.go b/roomserver/input/events.go index 64d696cd1..4fd5f69a3 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -280,53 +280,4 @@ func processInviteEvent( succeeded = true return nil } - -func buildInviteStrippedState( - ctx context.Context, - db RoomEventDatabase, - input api.InputInviteEvent, -) (json.RawMessage, error) { - roomNID, err := db.RoomNID(ctx, input.Event.RoomID()) - if err != nil || roomNID == 0 { - return nil, nil - } - stateWanted := []gomatrixserverlib.StateKeyTuple{} - for _, t := range []string{ - gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, - gomatrixserverlib.MRoomAliases, gomatrixserverlib.MRoomJoinRules, - } { - stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ - EventType: t, - StateKey: "", - }) - } - _, currentStateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID) - if err != nil { - return nil, err - } - roomState := state.NewStateResolution(db) - stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( - ctx, currentStateSnapshotNID, stateWanted, - ) - if err != nil { - return nil, err - } - stateNIDs := []types.EventNID{} - for _, stateNID := range stateEntries { - stateNIDs = append(stateNIDs, stateNID.EventNID) - } - stateEvents, err := db.Events(ctx, stateNIDs) - if err != nil { - return nil, err - } - inviteState := []gomatrixserverlib.InviteV2StrippedState{} - for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(&event.Event)) - } - inviteStrippedState, err := json.Marshal(inviteState) - if err != nil { - return nil, err - } - return inviteStrippedState, nil -} */ diff --git a/roomserver/input/input.go b/roomserver/input/input.go index fb69e9014..ee3eebf0f 100644 --- a/roomserver/input/input.go +++ b/roomserver/input/input.go @@ -72,6 +72,24 @@ func (r *RoomserverInputAPI) InputRoomEvents( return nil } +// InputNewInviteEvents implements api.RoomserverInputAPI +func (r *RoomserverInputAPI) InputNewInviteEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) (err error) { + for i := range request.InputRoomEvents { + inviteRoomState, err := buildInviteStrippedState(ctx, r.DB, request.InputRoomEvents[i]) + if err != nil { + return err + } + if err := request.InputRoomEvents[i].Event.SetUnsignedField("invite_room_state", inviteRoomState); err != nil { + return err + } + } + return r.InputRoomEvents(ctx, request, response) +} + // SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux. func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) { servMux.Handle(api.RoomserverInputRoomEventsPath, diff --git a/roomserver/input/invites.go b/roomserver/input/invites.go new file mode 100644 index 000000000..2217e9e7e --- /dev/null +++ b/roomserver/input/invites.go @@ -0,0 +1,60 @@ +package input + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +func buildInviteStrippedState( + ctx context.Context, + db RoomEventDatabase, + input api.InputRoomEvent, +) (json.RawMessage, error) { + roomNID, err := db.RoomNID(ctx, input.Event.RoomID()) + if err != nil || roomNID == 0 { + return nil, nil + } + stateWanted := []gomatrixserverlib.StateKeyTuple{} + for _, t := range []string{ + gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, + gomatrixserverlib.MRoomAliases, gomatrixserverlib.MRoomJoinRules, + } { + stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ + EventType: t, + StateKey: "", + }) + } + _, currentStateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID) + if err != nil { + return nil, err + } + roomState := state.NewStateResolution(db) + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, currentStateSnapshotNID, stateWanted, + ) + if err != nil { + return nil, err + } + stateNIDs := []types.EventNID{} + for _, stateNID := range stateEntries { + stateNIDs = append(stateNIDs, stateNID.EventNID) + } + stateEvents, err := db.Events(ctx, stateNIDs) + if err != nil { + return nil, err + } + inviteState := []gomatrixserverlib.InviteV2StrippedState{} + for _, event := range stateEvents { + inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(&event.Event)) + } + inviteStrippedState, err := json.Marshal(inviteState) + if err != nil { + return nil, err + } + return inviteStrippedState, nil +}