Remove PerformError (#3066)

This removes `PerformError`, which was needed when we still had
polylith.

This removes quite a bunch of
```go
if err != nil {
	return err
}
if err := res.Error; err != nil {
	return err.JSONResponse()
}
```

Hopefully can be read commit by commit. [skip ci]
This commit is contained in:
Till 2023-04-28 17:46:01 +02:00 committed by GitHub
parent 1432743d1a
commit 6b47cf0f6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 469 additions and 903 deletions

View file

@ -3,12 +3,14 @@ package routing
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -28,88 +30,60 @@ func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAP
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
if err := rsAPI.PerformAdminEvacuateRoom( affected, err := rsAPI.PerformAdminEvacuateRoom(req.Context(), vars["roomID"])
req.Context(), switch err {
&roomserverAPI.PerformAdminEvacuateRoomRequest{ case nil:
RoomID: vars["roomID"], case eventutil.ErrRoomNoExists:
}, return util.JSONResponse{
res, Code: http.StatusNotFound,
); err != nil { JSON: jsonerror.NotFound(err.Error()),
return util.ErrorResponse(err)
} }
if err := res.Error; err != nil { default:
return err.JSONResponse() logrus.WithError(err).WithField("roomID", vars["roomID"]).Error("Failed to evacuate room")
return util.ErrorResponse(err)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"affected": res.Affected, "affected": affected,
}, },
} }
} }
func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminEvacuateUser(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
userID := vars["userID"]
_, domain, err := gomatrixserverlib.SplitID('@', userID) affected, err := rsAPI.PerformAdminEvacuateUser(req.Context(), vars["userID"])
if err != nil { if err != nil {
logrus.WithError(err).WithField("userID", vars["userID"]).Error("Failed to evacuate user")
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
if !cfg.Matrix.IsLocalServerName(domain) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("User ID must belong to this server."),
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
if err := rsAPI.PerformAdminEvacuateUser(
req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID,
},
res,
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{ JSON: map[string]interface{}{
"affected": res.Affected, "affected": affected,
}, },
} }
} }
func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminPurgeRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
roomID := vars["roomID"]
res := &roomserverAPI.PerformAdminPurgeRoomResponse{} if err = rsAPI.PerformAdminPurgeRoom(context.Background(), vars["roomID"]); err != nil {
if err := rsAPI.PerformAdminPurgeRoom(
context.Background(),
&roomserverAPI.PerformAdminPurgeRoomRequest{
RoomID: roomID,
},
res,
); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: res, JSON: struct{}{},
} }
} }
@ -238,7 +212,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
} }
} }
func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -257,23 +231,22 @@ func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.De
JSON: jsonerror.MissingArgument("Expecting remote server name."), JSON: jsonerror.MissingArgument("Expecting remote server name."),
} }
} }
res := &roomserverAPI.PerformAdminDownloadStateResponse{} if err = rsAPI.PerformAdminDownloadState(req.Context(), roomID, device.UserID, spec.ServerName(serverName)); err != nil {
if err := rsAPI.PerformAdminDownloadState( if errors.Is(err, eventutil.ErrRoomNoExists) {
req.Context(), return util.JSONResponse{
&roomserverAPI.PerformAdminDownloadStateRequest{ Code: 200,
UserID: device.UserID, JSON: jsonerror.NotFound(eventutil.ErrRoomNoExists.Error()),
RoomID: roomID,
ServerName: spec.ServerName(serverName),
},
res,
); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
} }
if err := res.Error; err != nil { }
return err.JSONResponse() logrus.WithError(err).WithFields(logrus.Fields{
"userID": device.UserID,
"serverName": serverName,
"roomID": roomID,
}).Error("failed to download state")
return util.ErrorResponse(err)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: map[string]interface{}{}, JSON: struct{}{},
} }
} }

View file

@ -22,6 +22,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/getsentry/sentry-go"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -544,9 +545,10 @@ func createRoom(
} }
// Process the invites. // Process the invites.
var inviteEvent *types.HeaderedEvent
for _, invitee := range r.Invite { for _, invitee := range r.Invite {
// Build the invite event. // Build the invite event.
inviteEvent, err := buildMembershipEvent( inviteEvent, err = buildMembershipEvent(
ctx, invitee, "", profileAPI, device, spec.Invite, ctx, invitee, "", profileAPI, device, spec.Invite,
roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI, roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI,
) )
@ -559,38 +561,44 @@ func createRoom(
fclient.NewInviteV2StrippedState(inviteEvent.Event), fclient.NewInviteV2StrippedState(inviteEvent.Event),
) )
// Send the invite event to the roomserver. // Send the invite event to the roomserver.
var inviteRes roomserverAPI.PerformInviteResponse
event := inviteEvent event := inviteEvent
if err := rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{ err = rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{
Event: event, Event: event,
InviteRoomState: inviteStrippedState, InviteRoomState: inviteStrippedState,
RoomVersion: event.Version(), RoomVersion: event.Version(),
SendAsServer: string(userDomain), SendAsServer: string(userDomain),
}, &inviteRes); err != nil { })
switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }
} }
if inviteRes.Error != nil {
return inviteRes.Error.JSONResponse()
}
} }
} }
if r.Visibility == "public" { if r.Visibility == spec.Public {
// expose this room in the published room list // expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse if err = rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: "public", Visibility: spec.Public,
}, &pubRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(ctx, err) util.GetLogger(ctx).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
} }
} }

View file

@ -304,16 +304,12 @@ func SetVisibility(
return *reqErr return *reqErr
} }
var publishRes roomserverAPI.PerformPublishResponse if err = rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: v.Visibility, Visibility: v.Visibility,
}, &publishRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(req.Context(), err) util.GetLogger(req.Context()).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{
@ -342,18 +338,14 @@ func SetVisibilityAS(
return *reqErr return *reqErr
} }
} }
var publishRes roomserverAPI.PerformPublishResponse
if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID, RoomID: roomID,
Visibility: v.Visibility, Visibility: v.Visibility,
NetworkID: networkID, NetworkID: networkID,
AppserviceID: dev.AppserviceID, AppserviceID: dev.AppserviceID,
}, &publishRes); err != nil { }); err != nil {
return jsonerror.InternalAPIError(req.Context(), err) util.GetLogger(req.Context()).WithError(err).Error("failed to publish room")
} return jsonerror.InternalServerError()
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -15,14 +15,18 @@
package routing package routing
import ( import (
"encoding/json"
"errors"
"net/http" "net/http"
"time" "time"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -41,7 +45,6 @@ func JoinRoomByIDOrAlias(
IsGuest: device.AccountType == api.AccountTypeGuest, IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
} }
joinRes := roomserverAPI.PerformJoinResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
@ -81,37 +84,66 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1) done := make(chan util.JSONResponse, 1)
go func() { go func() {
defer close(done) defer close(done)
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { roomID, _, err := rsAPI.PerformJoin(req.Context(), &joinReq)
done <- jsonerror.InternalAPIError(req.Context(), err) var response util.JSONResponse
} else if joinRes.Error != nil {
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { switch e := err.(type) {
done <- util.JSONResponse{ case nil: // success case
Code: http.StatusForbidden, response = util.JSONResponse{
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else {
done <- util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{joinRes.RoomID}, }{roomID},
}
case roomserverAPI.ErrInvalidID:
response = util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
jsonErr := jsonerror.Forbidden(e.Error())
if device.AccountType == api.AccountTypeGuest {
jsonErr = jsonerror.GuestAccessForbidden(e.Error())
}
response = util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonErr,
}
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
response = util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
default:
response = util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
if errors.Is(err, eventutil.ErrRoomNoExists) {
response = util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(e.Error()),
} }
} }
}
done <- response
}() }()
// Wait either for the join to finish, or for us to hit a reasonable // Wait either for the join to finish, or for us to hit a reasonable
// timeout, at which point we'll just return a 200 to placate clients. // timeout, at which point we'll just return a 200 to placate clients.
timer := time.NewTimer(time.Second * 20)
select { select {
case <-time.After(time.Second * 20): case <-timer.C:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusAccepted, Code: http.StatusAccepted,
JSON: jsonerror.Unknown("The room join will continue in the background."), JSON: jsonerror.Unknown("The room join will continue in the background."),
} }
case result := <-done: case result := <-done:
// Stop and drain the timer
if !timer.Stop() {
<-timer.C
}
return result return result
} }
} }

View file

@ -20,6 +20,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -265,22 +266,33 @@ func sendInvite(
return jsonerror.InternalServerError(), err return jsonerror.InternalServerError(), err
} }
var inviteRes api.PerformInviteResponse err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
if err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
Event: event, Event: event,
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
RoomVersion: event.Version(), RoomVersion: event.Version(),
SendAsServer: string(device.UserDomain()), SendAsServer: string(device.UserDomain()),
}, &inviteRes); err != nil { })
switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}, e
case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}, e
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
}, err }, err
} }
if inviteRes.Error != nil {
return inviteRes.Error.JSONResponse(), inviteRes.Error
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -15,13 +15,16 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
func PeekRoomByIDOrAlias( func PeekRoomByIDOrAlias(
@ -41,8 +44,6 @@ func PeekRoomByIDOrAlias(
UserID: device.UserID, UserID: device.UserID,
DeviceID: device.ID, DeviceID: device.ID,
} }
peekRes := roomserverAPI.PerformPeekResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
if serverNames, ok := req.URL.Query()["server_name"]; ok { if serverNames, ok := req.URL.Query()["server_name"]; ok {
@ -55,11 +56,27 @@ func PeekRoomByIDOrAlias(
} }
// Ask the roomserver to perform the peek. // Ask the roomserver to perform the peek.
if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { roomID, err := rsAPI.PerformPeek(req.Context(), &peekReq)
return util.ErrorResponse(err) switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
} }
if peekRes.Error != nil { case roomserverAPI.ErrNotAllowed:
return peekRes.Error.JSONResponse() return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case *gomatrix.HTTPError:
return util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
case nil:
default:
logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room")
return jsonerror.InternalServerError()
} }
// if this user is already joined to the room, we let them peek anyway // if this user is already joined to the room, we let them peek anyway
@ -75,7 +92,7 @@ func PeekRoomByIDOrAlias(
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{peekRes.RoomID}, }{roomID},
} }
} }
@ -85,18 +102,17 @@ func UnpeekRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ err := rsAPI.PerformUnpeek(req.Context(), roomID, device.UserID, device.ID)
RoomID: roomID, switch e := err.(type) {
UserID: device.UserID, case roomserverAPI.ErrInvalidID:
DeviceID: device.ID, return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
} }
unpeekRes := roomserverAPI.PerformUnpeekResponse{} case nil:
default:
if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room")
return jsonerror.InternalAPIError(req.Context(), err) return jsonerror.InternalServerError()
}
if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -162,13 +162,13 @@ func Setup(
dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}",
httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminEvacuateUser(req, cfg, rsAPI) return AdminEvacuateUser(req, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}",
httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminPurgeRoom(req, cfg, device, rsAPI) return AdminPurgeRoom(req, rsAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -180,7 +180,7 @@ func Setup(
dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}", dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}",
httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminDownloadState(req, cfg, device, rsAPI) return AdminDownloadState(req, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)

View file

@ -15,11 +15,13 @@
package routing package routing
import ( import (
"errors"
"net/http" "net/http"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -57,38 +59,28 @@ func UpgradeRoom(
} }
} }
upgradeReq := roomserverAPI.PerformRoomUpgradeRequest{ newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion))
UserID: device.UserID, switch e := err.(type) {
RoomID: roomID, case nil:
RoomVersion: gomatrixserverlib.RoomVersion(r.NewVersion), case roomserverAPI.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
} }
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} default:
if errors.Is(err, eventutil.ErrRoomNoExists) {
if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil {
return jsonerror.InternalAPIError(req.Context(), err)
}
if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Room does not exist"), JSON: jsonerror.NotFound("Room does not exist"),
} }
} else if upgradeResp.Error.Code == roomserverAPI.PerformErrorNotAllowed {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(upgradeResp.Error.Msg),
} }
} else {
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: upgradeRoomResponse{ JSON: upgradeRoomResponse{
ReplacementRoom: upgradeResp.NewRoomID, ReplacementRoom: newRoomID,
}, },
} }
} }

View file

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -205,17 +206,36 @@ func processInvite(
SendAsServer: string(api.DoNotSendToOtherServers), SendAsServer: string(api.DoNotSendToOtherServers),
TransactionID: nil, TransactionID: nil,
} }
response := &api.PerformInviteResponse{}
if err := rsAPI.PerformInvite(ctx, request, response); err != nil { if err = rsAPI.PerformInvite(ctx, request); err != nil {
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(), JSON: jsonerror.InternalServerError(),
} }
} }
if response.Error != nil {
return response.Error.JSONResponse() switch e := err.(type) {
case api.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
} }
case api.ErrNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case nil:
default:
util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
sentry.CaptureException(err)
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
}
// Return the signed event to the originating server, it should then tell // Return the signed event to the originating server, it should then tell
// the other servers in the room that we have been invited. // the other servers in the room that we have been invited.
if isInviteV2 { if isInviteV2 {

View file

@ -11,6 +11,25 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
// ErrInvalidID is an error returned if the userID is invalid
type ErrInvalidID struct {
Err error
}
func (e ErrInvalidID) Error() string {
return e.Err.Error()
}
// ErrNotAllowed is an error returned if the user is not allowed
// to execute some action (e.g. invite)
type ErrNotAllowed struct {
Err error
}
func (e ErrNotAllowed) Error() string {
return e.Err.Error()
}
// RoomserverInputAPI is used to write events to the room server. // RoomserverInputAPI is used to write events to the room server.
type RoomserverInternalAPI interface { type RoomserverInternalAPI interface {
SyncRoomserverAPI SyncRoomserverAPI
@ -150,17 +169,17 @@ type ClientRoomserverAPI interface {
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error
PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error PerformPublish(ctx context.Context, req *PerformPublishRequest) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@ -172,8 +191,8 @@ type UserRoomserverAPI interface {
KeyserverRoomserverAPI KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {
@ -202,7 +221,7 @@ type FederationRoomserverAPI interface {
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest) error
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
} }

View file

@ -1,81 +1,11 @@
package api package api
import ( import (
"encoding/json" "github.com/matrix-org/dendrite/roomserver/types"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/roomserver/types"
)
type PerformErrorCode int
type PerformError struct {
Msg string
RemoteCode int // remote HTTP status code, for PerformErrRemote
Code PerformErrorCode
}
func (p *PerformError) Error() string {
return fmt.Sprintf("%d : %s", p.Code, p.Msg)
}
// JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500.
func (p *PerformError) JSONResponse() util.JSONResponse {
switch p.Code {
case PerformErrorBadRequest:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(p.Msg),
}
case PerformErrorNoRoom:
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(p.Msg),
}
case PerformErrorNotAllowed:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg),
}
case PerformErrorNoOperation:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg),
}
case PerformErrRemote:
// if the code is 0 then something bad happened and it isn't
// a remote HTTP error being encapsulated, e.g network error to remote.
if p.RemoteCode == 0 {
return util.ErrorResponse(fmt.Errorf("%s", p.Msg))
}
return util.JSONResponse{
Code: p.RemoteCode,
// TODO: Should we assert this is in fact JSON? E.g gjson parse?
JSON: json.RawMessage(p.Msg),
}
default:
return util.ErrorResponse(p)
}
}
const (
// PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned)
PerformErrorNotAllowed PerformErrorCode = 1
// PerformErrorBadRequest means the request was wrong in some way (invalid user ID, wrong server, etc)
PerformErrorBadRequest PerformErrorCode = 2
// PerformErrorNoRoom means that the room being joined doesn't exist.
PerformErrorNoRoom PerformErrorCode = 3
// PerformErrorNoOperation means that the request resulted in nothing happening e.g invite->invite or leave->leave.
PerformErrorNoOperation PerformErrorCode = 4
// PerformErrRemote means that the request failed and the PerformError.Msg is the raw remote JSON error response
PerformErrRemote PerformErrorCode = 5
) )
type PerformJoinRequest struct { type PerformJoinRequest struct {
@ -87,14 +17,6 @@ type PerformJoinRequest struct {
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformJoinResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
JoinedVia spec.ServerName
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -113,10 +35,6 @@ type PerformInviteRequest struct {
TransactionID *TransactionID `json:"transaction_id"` TransactionID *TransactionID `json:"transaction_id"`
} }
type PerformInviteResponse struct {
Error *PerformError
}
type PerformPeekRequest struct { type PerformPeekRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"` RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -124,24 +42,6 @@ type PerformPeekRequest struct {
ServerNames []spec.ServerName `json:"server_names"` ServerNames []spec.ServerName `json:"server_names"`
} }
type PerformPeekResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformUnpeekRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
}
type PerformUnpeekResponse struct {
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
// PerformBackfillRequest is a request to PerformBackfill. // PerformBackfillRequest is a request to PerformBackfill.
type PerformBackfillRequest struct { type PerformBackfillRequest struct {
// The room to backfill // The room to backfill
@ -180,11 +80,6 @@ type PerformPublishRequest struct {
NetworkID string NetworkID string
} }
type PerformPublishResponse struct {
// If non-nil, the publish request failed. Contains more information why it failed.
Error *PerformError
}
type PerformInboundPeekRequest struct { type PerformInboundPeekRequest struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
@ -214,50 +109,3 @@ type PerformForgetRequest struct {
} }
type PerformForgetResponse struct{} type PerformForgetResponse struct{}
type PerformRoomUpgradeRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
}
type PerformRoomUpgradeResponse struct {
NewRoomID string
Error *PerformError
}
type PerformAdminEvacuateRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminEvacuateRoomResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}
type PerformAdminEvacuateUserRequest struct {
UserID string `json:"user_id"`
}
type PerformAdminEvacuateUserResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}
type PerformAdminPurgeRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminPurgeRoomResponse struct {
Error *PerformError `json:"error,omitempty"`
}
type PerformAdminDownloadStateRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
ServerName spec.ServerName `json:"server_name"`
}
type PerformAdminDownloadStateResponse struct {
Error *PerformError `json:"error,omitempty"`
}

View file

@ -209,11 +209,9 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) error { ) error {
outputEvents, err := r.Inviter.PerformInvite(ctx, req, res) outputEvents, err := r.Inviter.PerformInvite(ctx, req)
if err != nil { if err != nil {
sentry.CaptureException(err)
return err return err
} }
if len(outputEvents) == 0 { if len(outputEvents) == 0 {

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -41,61 +42,44 @@ type Admin struct {
Leaver *Leaver Leaver *Leaver
} }
// PerformEvacuateRoom will remove all local users from the given room. // PerformAdminEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom( func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest, roomID string,
res *api.PerformAdminEvacuateRoomResponse, ) (affected []string, err error) {
) error { roomInfo, err := r.DB.RoomInfo(ctx, roomID)
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return nil
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{ return nil, eventutil.ErrRoomNoExists
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return nil
} }
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return nil
} }
memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return nil
} }
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
res.Affected = make([]string, 0, len(memberEvents)) affected = make([]string, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{ latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: req.RoomID, RoomID: roomID,
} }
latestRes := &api.QueryLatestEventsAndStateResponse{} latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return nil
} }
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName
var eventsNeeded gomatrixserverlib.StateNeeded
var identity *fclient.SigningIdentity
var event *types.HeaderedEvent
for _, memberEvent := range memberEvents { for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil { if memberEvent.StateKey() == nil {
continue continue
@ -103,57 +87,41 @@ func (r *Admin) PerformAdminEvacuateRoom(
var memberContent gomatrixserverlib.MemberContent var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return nil
} }
memberContent.Membership = spec.Leave memberContent.Membership = spec.Leave
stateKey := *memberEvent.StateKey() stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{ fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: req.RoomID, RoomID: roomID,
Type: spec.MRoomMember, Type: spec.MRoomMember,
StateKey: &stateKey, StateKey: &stateKey,
Sender: stateKey, Sender: stateKey,
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
_, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) _, senderDomain, err = gomatrixserverlib.SplitID('@', fledglingEvent.Sender)
if err != nil { if err != nil {
continue continue
} }
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return nil
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) eventsNeeded, err = gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return nil
} }
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err = r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
continue continue
} }
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) event, err = eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return nil
} }
inputEvents = append(inputEvents, api.InputRoomEvent{ inputEvents = append(inputEvents, api.InputRoomEvent{
@ -162,7 +130,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Origin: senderDomain, Origin: senderDomain,
SendAsServer: string(senderDomain), SendAsServer: string(senderDomain),
}) })
res.Affected = append(res.Affected, stateKey) affected = append(affected, stateKey)
prevEvents = []gomatrixserverlib.EventReference{ prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(), event.EventReference(),
} }
@ -173,108 +141,85 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true, Asynchronous: true,
} }
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
return affected, err
} }
// PerformAdminEvacuateUser will remove the given user from all rooms.
func (r *Admin) PerformAdminEvacuateUser( func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context, ctx context.Context,
req *api.PerformAdminEvacuateUserRequest, userID string,
res *api.PerformAdminEvacuateUserResponse, ) (affected []string, err error) {
) error { _, domain, err := gomatrixserverlib.SplitID('@', userID)
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err),
}
return nil
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
res.Error = &api.PerformError{ return nil, fmt.Errorf("can only evacuate local users using this endpoint")
Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint",
}
return nil
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join)
if err != nil {
return nil, err
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return nil
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, spec.Invite) allRooms := append(roomIDs, inviteRoomIDs...)
if err != nil && err != sql.ErrNoRows { affected = make([]string, 0, len(allRooms))
res.Error = &api.PerformError{ for _, roomID := range allRooms {
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
return nil
}
for _, roomID := range append(roomIDs, inviteRoomIDs...) {
leaveReq := &api.PerformLeaveRequest{ leaveReq := &api.PerformLeaveRequest{
RoomID: roomID, RoomID: roomID,
UserID: req.UserID, UserID: userID,
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
} }
return nil affected = append(affected, roomID)
}
res.Affected = append(res.Affected, roomID)
if len(outputEvents) == 0 { if len(outputEvents) == 0 {
continue continue
} }
if err := r.Inputer.OutputProducer.ProduceRoomEvents(roomID, outputEvents); err != nil { if err := r.Inputer.OutputProducer.ProduceRoomEvents(roomID, outputEvents); err != nil {
res.Error = &api.PerformError{ return nil, err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
}
return nil
} }
} }
return nil return affected, nil
} }
// PerformAdminPurgeRoom removes all traces for the given room from the database.
func (r *Admin) PerformAdminPurgeRoom( func (r *Admin) PerformAdminPurgeRoom(
ctx context.Context, ctx context.Context,
req *api.PerformAdminPurgeRoomRequest, roomID string,
res *api.PerformAdminPurgeRoomResponse,
) error { ) error {
// Validate we actually got a room ID and nothing else // Validate we actually got a room ID and nothing else
if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { if _, _, err := gomatrixserverlib.SplitID('!', roomID); err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed room ID: %s", err),
}
return nil
} }
logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") // Evacuate the room before purging it from the database
if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { if _, err := r.PerformAdminEvacuateRoom(ctx, roomID); err != nil {
logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to evacuate room before purging")
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: err.Error(),
}
return nil
} }
logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") logrus.WithField("room_id", roomID).Warn("Purging room from roomserver")
if err := r.DB.PurgeRoom(ctx, roomID); err != nil {
logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver")
return err
}
return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ logrus.WithField("room_id", roomID).Warn("Room purged from roomserver")
return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{
{ {
Type: api.OutputTypePurgeRoom, Type: api.OutputTypePurgeRoom,
PurgeRoom: &api.OutputPurgeRoom{ PurgeRoom: &api.OutputPurgeRoom{
RoomID: req.RoomID, RoomID: roomID,
}, },
}, },
}) })
@ -282,42 +227,25 @@ func (r *Admin) PerformAdminPurgeRoom(
func (r *Admin) PerformAdminDownloadState( func (r *Admin) PerformAdminDownloadState(
ctx context.Context, ctx context.Context,
req *api.PerformAdminDownloadStateRequest, roomID, userID string, serverName spec.ServerName,
res *api.PerformAdminDownloadStateResponse,
) error { ) error {
_, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err),
}
return nil
} }
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return nil
} }
if roomInfo == nil || roomInfo.IsStub() { if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{ return eventutil.ErrRoomNoExists
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("room %q not found", req.RoomID),
}
return nil
} }
fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return err
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.LatestEventIDs: %s", err),
}
return nil
} }
authEventMap := map[string]*gomatrixserverlib.Event{} authEventMap := map[string]*gomatrixserverlib.Event{}
@ -325,13 +253,9 @@ func (r *Admin) PerformAdminDownloadState(
for _, fwdExtremity := range fwdExtremities { for _, fwdExtremity := range fwdExtremities {
var state gomatrixserverlib.StateResponse var state gomatrixserverlib.StateResponse
state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, serverName, roomID, fwdExtremity.EventID, roomInfo.RoomVersion)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err),
}
return nil
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil {
@ -361,18 +285,14 @@ func (r *Admin) PerformAdminDownloadState(
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
Type: "org.matrix.dendrite.state_download", Type: "org.matrix.dendrite.state_download",
Sender: req.UserID, Sender: userID,
RoomID: req.RoomID, RoomID: roomID,
Content: spec.RawJSON("{}"), Content: spec.RawJSON("{}"),
} }
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return nil
} }
queryRes := &api.QueryLatestEventsAndStateResponse{ queryRes := &api.QueryLatestEventsAndStateResponse{
@ -390,11 +310,7 @@ func (r *Admin) PerformAdminDownloadState(
ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("eventutil.BuildEvent: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return nil
} }
inputReq := &api.InputRoomEventsRequest{ inputReq := &api.InputRoomEventsRequest{
@ -418,19 +334,12 @@ func (r *Admin) PerformAdminDownloadState(
SendAsServer: string(r.Cfg.Matrix.ServerName), SendAsServer: string(r.Cfg.Matrix.ServerName),
}) })
if err := r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil { if err = r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil {
res.Error = &api.PerformError{ return fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.InputRoomEvents: %s", err),
}
return nil
} }
if inputRes.ErrMsg != "" { if inputRes.ErrMsg != "" {
res.Error = &api.PerformError{ return inputRes.Err()
Code: api.PerformErrorBadRequest,
Msg: inputRes.ErrMsg,
}
} }
return nil return nil

View file

@ -45,7 +45,6 @@ type Inviter struct {
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
event := req.Event event := req.Event
@ -66,20 +65,12 @@ func (r *Inviter) PerformInvite(
_, domain, err := gomatrixserverlib.SplitID('@', targetUserID) _, domain, err := gomatrixserverlib.SplitID('@', targetUserID)
if err != nil { if err != nil {
res.Error = &api.PerformError{ return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", targetUserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("The user ID %q is invalid!", targetUserID),
}
return nil, nil
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain)
isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain)
if !isOriginLocal && !isTargetLocal { if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{ return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be either from or to a local user")}
Code: api.PerformErrorBadRequest,
Msg: "The invite must be either from or to a local user",
}
return nil, nil
} }
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
@ -175,12 +166,8 @@ func (r *Inviter) PerformInvite(
// For now we will implement option 2. Since in the abesence of a retry // For now we will implement option 2. Since in the abesence of a retry
// mechanism it will be equivalent to option 1, and we don't have a // mechanism it will be equivalent to option 1, and we don't have a
// signalling mechanism to implement option 3. // signalling mechanism to implement option 3.
res.Error = &api.PerformError{
Code: api.PerformErrorNotAllowed,
Msg: "User is already joined to room",
}
logger.Debugf("user already joined") logger.Debugf("user already joined")
return nil, nil return nil, api.ErrNotAllowed{Err: fmt.Errorf("user is already joined to room")}
} }
// If the invite originated remotely then we can't send an // If the invite originated remotely then we can't send an
@ -201,11 +188,7 @@ func (r *Inviter) PerformInvite(
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
) )
res.Error = &api.PerformError{ return nil, api.ErrNotAllowed{Err: err}
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
return nil, nil
} }
// If the invite originated from us and the target isn't local then we // If the invite originated from us and the target isn't local then we
@ -220,12 +203,8 @@ func (r *Inviter) PerformInvite(
} }
fsRes := &federationAPI.PerformInviteResponse{} fsRes := &federationAPI.PerformInviteResponse{}
if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
return nil, nil return nil, api.ErrNotAllowed{Err: err}
} }
event = fsRes.Event event = fsRes.Event
logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID())
@ -251,11 +230,8 @@ func (r *Inviter) PerformInvite(
return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
} }
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, api.ErrNotAllowed{Err: err}
} }
// Don't notify the sync api of this event in the same way as a federated invite so the invitee // Don't notify the sync api of this event in the same way as a federated invite so the invitee

View file

@ -54,32 +54,22 @@ type Joiner struct {
func (r *Joiner) PerformJoin( func (r *Joiner) PerformJoin(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse, ) (roomID string, joinedVia spec.ServerName, err error) {
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias, "room_id": req.RoomIDOrAlias,
"user_id": req.UserID, "user_id": req.UserID,
"servers": req.ServerNames, "servers": req.ServerNames,
}) })
logger.Info("User requested to room join") logger.Info("User requested to room join")
roomID, joinedVia, err := r.performJoin(context.Background(), req) roomID, joinedVia, err = r.performJoin(context.Background(), req)
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to join room") logger.WithError(err).Error("Failed to join room")
sentry.CaptureException(err) sentry.CaptureException(err)
perr, ok := err.(*rsAPI.PerformError) return "", "", err
if ok {
res.Error = perr
} else {
res.Error = &rsAPI.PerformError{
Msg: err.Error(),
}
}
return nil
} }
logger.Info("User joined room successfully") logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia return roomID, joinedVia, nil
return nil
} }
func (r *Joiner) performJoin( func (r *Joiner) performJoin(
@ -88,16 +78,10 @@ func (r *Joiner) performJoin(
) (string, spec.ServerName, error) { ) (string, spec.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performJoinRoomByID(ctx, req) return r.performJoinRoomByID(ctx, req)
@ -105,10 +89,7 @@ func (r *Joiner) performJoin(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAlias(ctx, req) return r.performJoinRoomByAlias(ctx, req)
} }
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByAlias(
@ -183,10 +164,7 @@ func (r *Joiner) performJoinRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
}
} }
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
@ -200,10 +178,7 @@ func (r *Joiner) performJoinRoomByID(
userID := req.UserID userID := req.UserID
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err),
}
} }
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{
Type: spec.MRoomMember, Type: spec.MRoomMember,
@ -287,10 +262,7 @@ func (r *Joiner) performJoinRoomByID(
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join. // is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" { if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
} }
} }
@ -342,16 +314,10 @@ func (r *Joiner) performJoinRoomByID(
} }
inputRes := rsAPI.InputRoomEventsResponse{} inputRes := rsAPI.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
} }
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
}
} }
} }
@ -364,10 +330,7 @@ func (r *Joiner) performJoinRoomByID(
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
if len(req.ServerNames) == 0 { if len(req.ServerNames) == 0 {
return "", "", &rsAPI.PerformError{ return "", "", eventutil.ErrRoomNoExists
Code: rsAPI.PerformErrorNoRoom,
Msg: fmt.Sprintf("room ID %q does not exist", req.RoomIDOrAlias),
}
} }
} }
@ -402,11 +365,7 @@ func (r *Joiner) performFederatedJoinRoomByID(
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &rsAPI.PerformError{ return "", fedRes.LastError
Code: rsAPI.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
return fedRes.JoinedVia, nil return fedRes.JoinedVia, nil
} }
@ -430,10 +389,7 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
return "", nil return "", nil
} }
if !res.Allowed { if !res.Allowed {
return "", &rsAPI.PerformError{ return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("The join to room %s was not allowed.", joinReq.RoomIDOrAlias),
}
} }
return res.AuthorisedVia, nil return res.AuthorisedVia, nil
} }

View file

@ -44,21 +44,8 @@ type Peeker struct {
func (r *Peeker) PerformPeek( func (r *Peeker) PerformPeek(
ctx context.Context, ctx context.Context,
req *api.PerformPeekRequest, req *api.PerformPeekRequest,
res *api.PerformPeekResponse, ) (roomID string, err error) {
) error { return r.performPeek(ctx, req)
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
res.RoomID = roomID
return nil
} }
func (r *Peeker) performPeek( func (r *Peeker) performPeek(
@ -68,16 +55,10 @@ func (r *Peeker) performPeek(
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performPeekRoomByID(ctx, req) return r.performPeekRoomByID(ctx, req)
@ -85,10 +66,7 @@ func (r *Peeker) performPeek(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performPeekRoomByAlias(ctx, req) return r.performPeekRoomByAlias(ctx, req)
} }
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Peeker) performPeekRoomByAlias( func (r *Peeker) performPeekRoomByAlias(
@ -98,7 +76,7 @@ func (r *Peeker) performPeekRoomByAlias(
// Get the domain part of the room alias. // Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias) return "", api.ErrInvalidID{Err: fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)}
} }
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, domain)
@ -147,10 +125,7 @@ func (r *Peeker) performPeekRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', roomID) _, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err),
}
} }
// handle federated peeks // handle federated peeks
@ -169,11 +144,7 @@ func (r *Peeker) performPeekRoomByID(
fedRes := fsAPI.PerformOutboundPeekResponse{} fedRes := fsAPI.PerformOutboundPeekResponse{}
_ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes) _ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &api.PerformError{ return "", fedRes.LastError
Code: api.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
} }
@ -194,17 +165,11 @@ func (r *Peeker) performPeekRoomByID(
} }
if !worldReadable { if !worldReadable {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("room is not world-readable")}
Code: api.PerformErrorNotAllowed,
Msg: "Room is not world-readable",
}
} }
if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil { if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("Cannot peek into an encrypted room")}
Code: api.PerformErrorNotAllowed,
Msg: "Cannot peek into an encrypted room",
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks

View file

@ -25,16 +25,10 @@ type Publisher struct {
DB storage.Database DB storage.Database
} }
// PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any.
func (r *Publisher) PerformPublish( func (r *Publisher) PerformPublish(
ctx context.Context, ctx context.Context,
req *api.PerformPublishRequest, req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
) error { ) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public") return r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public")
if err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
return nil
} }

View file

@ -34,84 +34,48 @@ type Unpeeker struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. // PerformUnpeek handles un-peeking matrix rooms, including over federation by talking to the federationapi.
func (r *Unpeeker) PerformUnpeek( func (r *Unpeeker) PerformUnpeek(
ctx context.Context, ctx context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
res *api.PerformUnpeekResponse,
) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
return nil
}
func (r *Unpeeker) performUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
) error { ) error {
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
} }
if strings.HasPrefix(roomID, "!") {
return r.performUnpeekRoomByID(ctx, roomID, userID, deviceID)
} }
if strings.HasPrefix(req.RoomID, "!") { return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid", roomID)}
return r.performUnpeekRoomByID(ctx, req)
}
return &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid", req.RoomID),
}
} }
func (r *Unpeeker) performUnpeekRoomByID( func (r *Unpeeker) performUnpeekRoomByID(
_ context.Context, _ context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
) (err error) { ) (err error) {
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, _, err = gomatrixserverlib.SplitID('!', req.RoomID) _, _, err = gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomID, err),
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks
err = r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: req.RoomID,
UserID: req.UserID,
DeviceID: req.DeviceID,
},
},
})
if err != nil {
return
}
// By this point, if req.RoomIDOrAlias contained an alias, then // By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performPeekRoomByAlias. // it will have been overwritten with a room ID by performPeekRoomByAlias.
// We should now include this in the response so that the CS API can // We should now include this in the response so that the CS API can
// return the right room ID. // return the right room ID.
return nil return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
},
},
})
} }

View file

@ -45,46 +45,29 @@ type fledglingEvent struct {
// PerformRoomUpgrade upgrades a room from one version to another // PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade( func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context, ctx context.Context,
req *api.PerformRoomUpgradeRequest, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
res *api.PerformRoomUpgradeResponse, ) (newRoomID string, err error) {
) error { return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil {
res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
}
return nil
} }
func (r *Upgrader) performRoomUpgrade( func (r *Upgrader) performRoomUpgrade(
ctx context.Context, ctx context.Context,
req *api.PerformRoomUpgradeRequest, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
) (string, *api.PerformError) { ) (string, error) {
roomID := req.RoomID
userID := req.UserID
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
Code: api.PerformErrorNotAllowed,
Msg: "Error validating the user ID",
}
} }
evTime := time.Now() evTime := time.Now()
// Return an immediate error if the room does not exist // Return an immediate error if the room does not exist
if err := r.validateRoomExists(ctx, roomID); err != nil { if err := r.validateRoomExists(ctx, roomID); err != nil {
return "", &api.PerformError{ return "", err
Code: api.PerformErrorNoRoom,
Msg: "Error validating that the room exists",
}
} }
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
if !r.userIsAuthorized(ctx, userID, roomID) { if !r.userIsAuthorized(ctx, userID, roomID) {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
Code: api.PerformErrorNotAllowed,
Msg: "You don't have permission to upgrade the room, power level too low.",
}
} }
// TODO (#267): Check room ID doesn't clash with an existing one, and we // TODO (#267): Check room ID doesn't clash with an existing one, and we
@ -97,9 +80,7 @@ func (r *Upgrader) performRoomUpgrade(
} }
oldRoomRes := &api.QueryLatestEventsAndStateResponse{} oldRoomRes := &api.QueryLatestEventsAndStateResponse{}
if err := r.URSAPI.QueryLatestEventsAndState(ctx, oldRoomReq, oldRoomRes); err != nil { if err := r.URSAPI.QueryLatestEventsAndState(ctx, oldRoomReq, oldRoomRes); err != nil {
return "", &api.PerformError{ return "", fmt.Errorf("Failed to get latest state: %s", err)
Msg: fmt.Sprintf("Failed to get latest state: %s", err),
}
} }
// Make the tombstone event // Make the tombstone event
@ -110,13 +91,13 @@ func (r *Upgrader) performRoomUpgrade(
// Generate the initial events we need to send into the new room. This includes copied state events and bans // Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room // as well as the power level events needed to set up the room
eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, string(req.RoomVersion), tombstoneEvent) eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent)
if pErr != nil { if pErr != nil {
return "", pErr return "", pErr
} }
// Send the setup events to the new room // Send the setup events to the new room
if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr return "", pErr
} }
@ -148,22 +129,15 @@ func (r *Upgrader) performRoomUpgrade(
return newRoomID, nil return newRoomID, nil
} }
func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, *api.PerformError) { func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, error) {
oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels, EventType: spec.MRoomPowerLevels,
StateKey: "", StateKey: "",
}) })
powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() return oldPowerLevelsEvent.PowerLevels()
if err != nil {
util.GetLogger(ctx).WithError(err).Error()
return nil, &api.PerformError{
Msg: "Power level event was invalid or malformed",
}
}
return powerLevelContent, nil
} }
func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil { if pErr != nil {
return pErr return pErr
@ -185,54 +159,46 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
StateKey: "", StateKey: "",
Content: restrictedPowerLevelContent, Content: restrictedPowerLevelContent,
}) })
if resErr != nil {
if resErr.Code == api.PerformErrorNotAllowed { switch resErr.(type) {
case api.ErrNotAllowed:
util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room") util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room")
} else { case nil:
return r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers)
default:
return resErr return resErr
} }
} else {
if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil {
return resErr
}
}
return nil return nil
} }
func moveLocalAliases(ctx context.Context, func moveLocalAliases(ctx context.Context,
roomID, newRoomID, userID string, roomID, newRoomID, userID string,
URSAPI api.RoomserverInternalAPI) *api.PerformError { URSAPI api.RoomserverInternalAPI,
var err error ) (err error) {
aliasReq := api.GetAliasesForRoomIDRequest{RoomID: roomID} aliasReq := api.GetAliasesForRoomIDRequest{RoomID: roomID}
aliasRes := api.GetAliasesForRoomIDResponse{} aliasRes := api.GetAliasesForRoomIDResponse{}
if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil { if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to get old room aliases: %w", err)
Msg: fmt.Sprintf("Failed to get old room aliases: %s", err),
}
} }
for _, alias := range aliasRes.Aliases { for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{} removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to remove old room alias: %w", err)
Msg: fmt.Sprintf("Failed to remove old room alias: %s", err),
}
} }
setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID}
setAliasRes := api.SetRoomAliasResponse{} setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to set new room alias: %w", err)
Msg: fmt.Sprintf("Failed to set new room alias: %s", err),
}
} }
} }
return nil return nil
} }
func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) *api.PerformError { func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue continue
@ -242,9 +208,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
AltAliases []string `json:"alt_aliases"` AltAliases []string `json:"alt_aliases"`
} }
if err := json.Unmarshal(event.Content(), &aliasContent); err != nil { if err := json.Unmarshal(event.Content(), &aliasContent); err != nil {
return &api.PerformError{ return fmt.Errorf("failed to unmarshal canonical aliases: %w", err)
Msg: fmt.Sprintf("Failed to unmarshal canonical aliases: %s", err),
}
} }
if aliasContent.Alias == "" && len(aliasContent.AltAliases) == 0 { if aliasContent.Alias == "" && len(aliasContent.AltAliases) == 0 {
// There are no canonical aliases to clear, therefore do nothing. // There are no canonical aliases to clear, therefore do nothing.
@ -256,30 +220,25 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
Type: spec.MRoomCanonicalAlias, Type: spec.MRoomCanonicalAlias,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
}) })
if resErr != nil { switch resErr.(type) {
if resErr.Code == api.PerformErrorNotAllowed { case api.ErrNotAllowed:
util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room") util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room")
} else { case nil:
return r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers)
default:
return resErr return resErr
} }
} else {
if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil {
return resErr
}
}
return nil return nil
} }
func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) *api.PerformError { func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) error {
// check if the old room was published // check if the old room was published
var pubQueryRes api.QueryPublishedRoomsResponse var pubQueryRes api.QueryPublishedRoomsResponse
err := r.URSAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{ err := r.URSAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{
RoomID: roomID, RoomID: roomID,
}, &pubQueryRes) }, &pubQueryRes)
if err != nil { if err != nil {
return &api.PerformError{ return err
Msg: "QueryPublishedRooms failed",
}
} }
// if the old room is published (was public), publish the new room // if the old room is published (was public), publish the new room
@ -295,36 +254,27 @@ func publishNewRoomAndUnpublishOldRoom(
oldRoomID, newRoomID string, oldRoomID, newRoomID string,
) { ) {
// expose this room in the published room list // expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID, RoomID: newRoomID,
Visibility: "public", Visibility: spec.Public,
}, &pubNewRoomRes); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") util.GetLogger(ctx).WithError(err).Error("failed to publish room")
} }
var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list // remove the old room from the published room list
if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID, RoomID: oldRoomID,
Visibility: "private", Visibility: "private",
}, &unpubOldRoomRes); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
} else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point // treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") util.GetLogger(ctx).WithError(err).Error("failed to un-publish room")
} }
} }
func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error { func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error {
if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil { if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil {
return &api.PerformError{ return eventutil.ErrRoomNoExists
Code: api.PerformErrorNoRoom,
Msg: "Room does not exist",
}
} }
return nil return nil
} }
@ -348,7 +298,7 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
} }
// nolint:gocyclo // nolint:gocyclo
func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID, newVersion string, tombstoneEvent *types.HeaderedEvent) ([]fledglingEvent, *api.PerformError) { func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]fledglingEvent, error) {
state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
for _, event := range oldRoom.StateEvents { for _, event := range oldRoom.StateEvents {
if event.StateKey() == nil { if event.StateKey() == nil {
@ -391,9 +341,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// old room state. Check that they are there. // old room state. Check that they are there.
for tuple := range override { for tuple := range override {
if _, ok := state[tuple]; !ok { if _, ok := state[tuple]; !ok {
return nil, &api.PerformError{ return nil, fmt.Errorf("essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey)
Msg: fmt.Sprintf("Essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey),
}
} }
} }
@ -440,9 +388,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() powerLevelContent, err := oldPowerLevelsEvent.PowerLevels()
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error() util.GetLogger(ctx).WithError(err).Error()
return nil, &api.PerformError{ return nil, fmt.Errorf("Power level event content was invalid")
Msg: "Power level event content was invalid",
}
} }
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID) tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID)
@ -506,7 +452,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil return eventsToMake, nil
} }
func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []fledglingEvent) error {
var err error var err error
var builtEvents []*types.HeaderedEvent var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
@ -522,34 +468,27 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
} }
err = builder.SetContent(e.Content) err = builder.SetContent(e.Content)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to set content of new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to set content of new %q event: %s", builder.Type, err),
}
} }
if i > 0 { if i > 0 {
builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()}
} }
var event *gomatrixserverlib.Event var event *gomatrixserverlib.Event
event, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion), r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey) event, err = builder.AddAuthEventsAndBuild(userDomain, &authEvents, evTime, newVersion, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to build new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
}
} }
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
return &api.PerformError{ return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err),
}
} }
// Add the event to the list of auth events // Add the event to the list of auth events
builtEvents = append(builtEvents, &types.HeaderedEvent{Event: event}) builtEvents = append(builtEvents, &types.HeaderedEvent{Event: event})
err = authEvents.AddEvent(event) err = authEvents.AddEvent(event)
if err != nil { if err != nil {
return &api.PerformError{ return fmt.Errorf("failed to add new %q event to auth set: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to add new %q event to auth set: %s", builder.Type, err),
}
} }
} }
@ -563,9 +502,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
}) })
} }
if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil {
return &api.PerformError{ return fmt.Errorf("failed to send new room %q to roomserver: %w", newRoomID, err)
Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err),
}
} }
return nil return nil
} }
@ -574,7 +511,7 @@ func (r *Upgrader) makeTombstoneEvent(
ctx context.Context, ctx context.Context,
evTime time.Time, evTime time.Time,
userID, roomID, newRoomID string, userID, roomID, newRoomID string,
) (*types.HeaderedEvent, *api.PerformError) { ) (*types.HeaderedEvent, error) {
content := map[string]interface{}{ content := map[string]interface{}{
"body": "This room has been replaced", "body": "This room has been replaced",
"replacement_room": newRoomID, "replacement_room": newRoomID,
@ -586,7 +523,7 @@ func (r *Upgrader) makeTombstoneEvent(
return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event)
} }
func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*types.HeaderedEvent, *api.PerformError) { func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*types.HeaderedEvent, error) {
builder := gomatrixserverlib.EventBuilder{ builder := gomatrixserverlib.EventBuilder{
Sender: userID, Sender: userID,
RoomID: roomID, RoomID: roomID,
@ -595,47 +532,27 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
} }
err := builder.SetContent(event.Content) err := builder.SetContent(event.Content)
if err != nil { if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to set new %q event content: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err),
}
} }
// Get the sender domain. // Get the sender domain.
_, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender)
if serr != nil { if serr != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("Failed to split user ID %q: %w", builder.Sender, err)
Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err),
}
} }
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil { if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err),
}
} }
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes)
if err == eventutil.ErrRoomNoExists { if err == eventutil.ErrRoomNoExists {
return nil, &api.PerformError{ return nil, err
Code: api.PerformErrorNoRoom,
Msg: "Room does not exist",
}
} else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok {
return nil, &api.PerformError{ return nil, e
Msg: e.Error(),
}
} else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok {
if e.Code == gomatrixserverlib.EventValidationTooLarge { return nil, e
return nil, &api.PerformError{
Msg: e.Error(),
}
}
return nil, &api.PerformError{
Msg: e.Error(),
}
} else if err != nil { } else if err != nil {
return nil, &api.PerformError{ return nil, fmt.Errorf("failed to build new %q event: %w", builder.Type, err)
Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err),
}
} }
// check to see if this user can perform this operation // check to see if this user can perform this operation
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
@ -644,10 +561,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil { if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil {
return nil, &api.PerformError{ return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", builder.Type, err)} // TODO: Is this error string comprehensible to the client?
Code: api.PerformErrorNotAllowed,
Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err), // TODO: Is this error string comprehensible to the client?
}
} }
return headeredEvent, nil return headeredEvent, nil
@ -695,7 +609,7 @@ func (r *Upgrader) sendHeaderedEvent(
serverName spec.ServerName, serverName spec.ServerName,
headeredEvent *types.HeaderedEvent, headeredEvent *types.HeaderedEvent,
sendAsServer string, sendAsServer string,
) *api.PerformError { ) error {
var inputs []api.InputRoomEvent var inputs []api.InputRoomEvent
inputs = append(inputs, api.InputRoomEvent{ inputs = append(inputs, api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
@ -703,11 +617,5 @@ func (r *Upgrader) sendHeaderedEvent(
Origin: serverName, Origin: serverName,
SendAsServer: sendAsServer, SendAsServer: sendAsServer,
}) })
if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false)
return &api.PerformError{
Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err),
}
}
return nil
} }

View file

@ -254,13 +254,9 @@ func TestPurgeRoom(t *testing.T) {
} }
// some dummy entries to validate after purging // some dummy entries to validate after purging
publishResp := &api.PerformPublishResponse{} if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: spec.Public}); err != nil {
if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if publishResp.Error != nil {
t.Fatal(publishResp.Error)
}
isPublished, err := db.GetPublishedRoom(ctx, room.ID) isPublished, err := db.GetPublishedRoom(ctx, room.ID)
if err != nil { if err != nil {
@ -328,8 +324,7 @@ func TestPurgeRoom(t *testing.T) {
} }
// purge the room from the database // purge the room from the database
purgeResp := &api.PerformAdminPurgeRoomResponse{} if err = rsAPI.PerformAdminPurgeRoom(ctx, room.ID); err != nil {
if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -926,7 +921,7 @@ func TestUpgrade(t *testing.T) {
if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{ if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: r.ID, RoomID: r.ID,
Visibility: spec.Public, Visibility: spec.Public,
}, &api.PerformPublishResponse{}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1070,25 +1065,19 @@ func TestUpgrade(t *testing.T) {
} }
roomID := tc.roomFunc(rsAPI) roomID := tc.roomFunc(rsAPI)
upgradeReq := api.PerformRoomUpgradeRequest{ newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion())
RoomID: roomID, if err != nil && tc.wantNewRoom {
UserID: tc.upgradeUser,
RoomVersion: version.DefaultRoomVersion(), // always upgrade to the latest version
}
upgradeRes := api.PerformRoomUpgradeResponse{}
if err := rsAPI.PerformRoomUpgrade(processCtx.Context(), &upgradeReq, &upgradeRes); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tc.wantNewRoom && upgradeRes.NewRoomID == "" { if tc.wantNewRoom && newRoomID == "" {
t.Fatalf("expected a new room, but the upgrade failed") t.Fatalf("expected a new room, but the upgrade failed")
} }
if !tc.wantNewRoom && upgradeRes.NewRoomID != "" { if !tc.wantNewRoom && newRoomID != "" {
t.Fatalf("expected no new room, but the upgrade succeeded") t.Fatalf("expected no new room, but the upgrade succeeded")
} }
if tc.validateFunc != nil { if tc.validateFunc != nil {
tc.validateFunc(t, roomID, upgradeRes.NewRoomID, rsAPI) tc.validateFunc(t, roomID, newRoomID, rsAPI)
} }
}) })
} }

View file

@ -139,7 +139,7 @@ type Database interface {
// not found. // not found.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory. // PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error)

View file

@ -172,8 +172,8 @@ func addUserToRoom(
UserID: userID, UserID: userID,
Content: addGroupContent, Content: addGroupContent,
} }
joinRes := rsapi.PerformJoinResponse{} _, _, err := rsAPI.PerformJoin(ctx, &joinReq)
return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) return err
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
@ -624,33 +624,28 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return fmt.Errorf("server name %q not locally configured", serverName) return fmt.Errorf("server name %q not locally configured", serverName)
} }
evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ userID := fmt.Sprintf("@%s:%s", req.Localpart, serverName)
UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), _, err := a.RSAPI.PerformAdminEvacuateUser(ctx, userID)
} if err != nil {
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} logrus.WithError(err).WithField("userID", userID).Errorf("Failed to evacuate user after account deactivation")
if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
return err
}
if err := evacuateRes.Error; err != nil {
logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation")
} }
deviceReq := &api.PerformDeviceDeletionRequest{ deviceReq := &api.PerformDeviceDeletionRequest{
UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
} }
deviceRes := &api.PerformDeviceDeletionResponse{} deviceRes := &api.PerformDeviceDeletionResponse{}
if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { if err = a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil {
return err return err
} }
pusherReq := &api.PerformPusherDeletionRequest{ pusherReq := &api.PerformPusherDeletionRequest{
Localpart: req.Localpart, Localpart: req.Localpart,
} }
if err := a.PerformPusherDeletion(ctx, pusherReq, &struct{}{}); err != nil { if err = a.PerformPusherDeletion(ctx, pusherReq, &struct{}{}); err != nil {
return err return err
} }
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) err = a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }