Slightly refactor sendEvent and CreateRoom so it can be reused

This commit is contained in:
Till Faelligen 2022-02-11 17:12:49 +01:00
parent 2819183c2b
commit e2bf73fce2
2 changed files with 69 additions and 62 deletions

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -140,33 +141,14 @@ func CreateRoom(
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse { ) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
userID := device.UserID
var r createRoomRequest var r createRoomRequest
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
// TODO: apply rate-limit
if resErr = r.Validate(); resErr != nil { if resErr = r.Validate(); resErr != nil {
return *resErr return *resErr
} }
evTime, err := httputil.ParseTSParam(req) evTime, err := httputil.ParseTSParam(req)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -174,6 +156,25 @@ func createRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime)
}
// createRoom implements /createRoom
// nolint: gocyclo
func createRoom(
ctx context.Context,
r createRoomRequest, device *api.Device,
cfg *config.ClientAPI,
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
evTime time.Time,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
logger := util.GetLogger(ctx)
userID := device.UserID
// Clobber keys: creator, room_version // Clobber keys: creator, room_version
@ -200,16 +201,16 @@ func createRoom(
"roomVersion": roomVersion, "roomVersion": roomVersion,
}).Info("Creating new room") }).Info("Creating new room")
profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
createContent := map[string]interface{}{} createContent := map[string]interface{}{}
if len(r.CreationContent) > 0 { if len(r.CreationContent) > 0 {
if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { if err = json.Unmarshal(r.CreationContent, &createContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed") util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("invalid create content"), JSON: jsonerror.BadJSON("invalid create content"),
@ -230,7 +231,7 @@ func createRoom(
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed") util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed power_level_content_override"), JSON: jsonerror.BadJSON("malformed power_level_content_override"),
@ -319,9 +320,9 @@ func createRoom(
} }
var aliasResp roomserverAPI.GetRoomIDForAliasResponse var aliasResp roomserverAPI.GetRoomIDForAliasResponse
err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if aliasResp.RoomID != "" { if aliasResp.RoomID != "" {
@ -426,7 +427,7 @@ func createRoom(
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if i > 0 { if i > 0 {
@ -435,12 +436,12 @@ func createRoom(
var ev *gomatrixserverlib.Event var ev *gomatrixserverlib.Event
ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed") util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -448,7 +449,7 @@ func createRoom(
builtEvents = append(builtEvents, ev.Headered(roomVersion)) builtEvents = append(builtEvents, ev.Headered(roomVersion))
err = authEvents.AddEvent(ev) err = authEvents.AddEvent(ev)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
} }
@ -462,8 +463,8 @@ func createRoom(
SendAsServer: roomserverAPI.DoNotSendToOtherServers, SendAsServer: roomserverAPI.DoNotSendToOtherServers,
}) })
} }
if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -478,9 +479,9 @@ func createRoom(
} }
var aliasResp roomserverAPI.SetRoomAliasResponse var aliasResp roomserverAPI.SetRoomAliasResponse
err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -519,11 +520,11 @@ func createRoom(
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
// Build the invite event. // Build the invite event.
inviteEvent, err := buildMembershipEvent( inviteEvent, err := buildMembershipEvent(
req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite, ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite,
roomID, true, cfg, evTime, rsAPI, asAPI, roomID, true, cfg, evTime, rsAPI, asAPI,
) )
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed")
continue continue
} }
inviteStrippedState := append( inviteStrippedState := append(
@ -532,7 +533,7 @@ func createRoom(
) )
// Send the invite event to the roomserver. // Send the invite event to the roomserver.
err = roomserverAPI.SendInvite( err = roomserverAPI.SendInvite(
req.Context(), ctx,
rsAPI, rsAPI,
inviteEvent.Headered(roomVersion), inviteEvent.Headered(roomVersion),
inviteStrippedState, // invite room state inviteStrippedState, // invite room state
@ -544,7 +545,7 @@ func createRoom(
return e.JSONResponse() return e.JSONResponse()
case nil: case nil:
default: default:
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
@ -556,13 +557,13 @@ func createRoom(
if r.Visibility == "public" { if r.Visibility == "public" {
// expose this room in the published room list // expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse var pubRes roomserverAPI.PerformPublishResponse
rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: "public", Visibility: "public",
}, &pubRes) }, &pubRes)
if pubRes.Error != nil { if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
} }
} }

View file

@ -15,10 +15,16 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -26,10 +32,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
) )
// http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid
@ -97,7 +99,22 @@ func SendEvent(
defer mutex.(*sync.Mutex).Unlock() defer mutex.(*sync.Mutex).Unlock()
startedGeneratingEvent := time.Now() startedGeneratingEvent := time.Now()
e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI)
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -153,27 +170,16 @@ func SendEvent(
} }
func generateSendEvent( func generateSendEvent(
req *http.Request, ctx context.Context,
r map[string]interface{},
device *userapi.Device, device *userapi.Device,
roomID, eventType string, stateKey *string, roomID, eventType string, stateKey *string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
evTime time.Time,
) (*gomatrixserverlib.Event, *util.JSONResponse) { ) (*gomatrixserverlib.Event, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
userID := device.UserID userID := device.UserID
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return nil, resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
// create the new event and set all the fields we can // create the new event and set all the fields we can
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
@ -182,15 +188,15 @@ func generateSendEvent(
Type: eventType, Type: eventType,
StateKey: stateKey, StateKey: stateKey,
} }
err = builder.SetContent(r) err := builder.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -213,7 +219,7 @@ func generateSendEvent(
JSON: jsonerror.BadJSON(e.Error()), JSON: jsonerror.BadJSON(e.Error()),
} }
} else if err != nil { } else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return nil, &resErr return nil, &resErr
} }