Remove PerformError from peeking and joining endpoints

This commit is contained in:
Till Faelligen 2023-04-25 09:12:33 +02:00
parent e3f1456517
commit 4b7875fb27
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
8 changed files with 137 additions and 258 deletions

View file

@ -15,14 +15,18 @@
package routing package routing
import ( import (
"encoding/json"
"errors"
"net/http" "net/http"
"time" "time"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -41,7 +45,6 @@ func JoinRoomByIDOrAlias(
IsGuest: device.AccountType == api.AccountTypeGuest, IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
} }
joinRes := roomserverAPI.PerformJoinResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
@ -81,37 +84,66 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1) done := make(chan util.JSONResponse, 1)
go func() { go func() {
defer close(done) defer close(done)
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { roomID, _, err := rsAPI.PerformJoin(req.Context(), &joinReq)
done <- jsonerror.InternalAPIError(req.Context(), err) var response util.JSONResponse
} else if joinRes.Error != nil {
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { switch e := err.(type) {
done <- util.JSONResponse{ case nil: // success case
Code: http.StatusForbidden, response = util.JSONResponse{
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else {
done <- util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{joinRes.RoomID}, }{roomID},
}
case roomserverAPI.ErrInvalidID:
response = util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
jsonErr := jsonerror.Forbidden(e.Error())
if device.AccountType == api.AccountTypeGuest {
jsonErr = jsonerror.GuestAccessForbidden(e.Error())
}
response = util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonErr,
}
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
response = util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
default:
response = util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalServerError(),
}
if errors.Is(err, eventutil.ErrRoomNoExists) {
response = util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(e.Error()),
} }
} }
}
done <- response
}() }()
// Wait either for the join to finish, or for us to hit a reasonable // Wait either for the join to finish, or for us to hit a reasonable
// timeout, at which point we'll just return a 200 to placate clients. // timeout, at which point we'll just return a 200 to placate clients.
timer := time.NewTimer(time.Second * 20)
select { select {
case <-time.After(time.Second * 20): case <-timer.C:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusAccepted, Code: http.StatusAccepted,
JSON: jsonerror.Unknown("The room join will continue in the background."), JSON: jsonerror.Unknown("The room join will continue in the background."),
} }
case result := <-done: case result := <-done:
// Stop and drain the timer
if !timer.Stop() {
<-timer.C
}
return result return result
} }
} }

View file

@ -15,13 +15,16 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
func PeekRoomByIDOrAlias( func PeekRoomByIDOrAlias(
@ -41,8 +44,6 @@ func PeekRoomByIDOrAlias(
UserID: device.UserID, UserID: device.UserID,
DeviceID: device.ID, DeviceID: device.ID,
} }
peekRes := roomserverAPI.PerformPeekResponse{}
// Check to see if any ?server_name= query parameters were // Check to see if any ?server_name= query parameters were
// given in the request. // given in the request.
if serverNames, ok := req.URL.Query()["server_name"]; ok { if serverNames, ok := req.URL.Query()["server_name"]; ok {
@ -55,11 +56,27 @@ func PeekRoomByIDOrAlias(
} }
// Ask the roomserver to perform the peek. // Ask the roomserver to perform the peek.
if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { roomID, err := rsAPI.PerformPeek(req.Context(), &peekReq)
return util.ErrorResponse(err) switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
} }
if peekRes.Error != nil { case roomserverAPI.ErrNotAllowed:
return peekRes.Error.JSONResponse() return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(e.Error()),
}
case *gomatrix.HTTPError:
return util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
case nil:
default:
logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room")
return jsonerror.InternalServerError()
} }
// if this user is already joined to the room, we let them peek anyway // if this user is already joined to the room, we let them peek anyway
@ -75,7 +92,7 @@ func PeekRoomByIDOrAlias(
// TODO: Put the response struct somewhere internal. // TODO: Put the response struct somewhere internal.
JSON: struct { JSON: struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{peekRes.RoomID}, }{roomID},
} }
} }
@ -85,18 +102,17 @@ func UnpeekRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{ err := rsAPI.PerformUnpeek(req.Context(), roomID, device.UserID, device.ID)
RoomID: roomID, switch e := err.(type) {
UserID: device.UserID, case roomserverAPI.ErrInvalidID:
DeviceID: device.ID, return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(e.Error()),
} }
unpeekRes := roomserverAPI.PerformUnpeekResponse{} case nil:
default:
if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room")
return jsonerror.InternalAPIError(req.Context(), err) return jsonerror.InternalServerError()
}
if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse()
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -174,10 +174,10 @@ type ClientRoomserverAPI interface {
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error PerformAdminPurgeRoom(ctx context.Context, roomID string) error
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error PerformInvite(ctx context.Context, req *PerformInviteRequest) error
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
@ -192,7 +192,7 @@ type UserRoomserverAPI interface {
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
} }
type FederationRoomserverAPI interface { type FederationRoomserverAPI interface {

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -28,11 +27,6 @@ func (p *PerformError) Error() string {
// JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500. // JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500.
func (p *PerformError) JSONResponse() util.JSONResponse { func (p *PerformError) JSONResponse() util.JSONResponse {
switch p.Code { switch p.Code {
case PerformErrorBadRequest:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(p.Msg),
}
case PerformErrorNoRoom: case PerformErrorNoRoom:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -43,22 +37,6 @@ func (p *PerformError) JSONResponse() util.JSONResponse {
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg), JSON: jsonerror.Forbidden(p.Msg),
} }
case PerformErrorNoOperation:
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(p.Msg),
}
case PerformErrRemote:
// if the code is 0 then something bad happened and it isn't
// a remote HTTP error being encapsulated, e.g network error to remote.
if p.RemoteCode == 0 {
return util.ErrorResponse(fmt.Errorf("%s", p.Msg))
}
return util.JSONResponse{
Code: p.RemoteCode,
// TODO: Should we assert this is in fact JSON? E.g gjson parse?
JSON: json.RawMessage(p.Msg),
}
default: default:
return util.ErrorResponse(p) return util.ErrorResponse(p)
} }
@ -67,14 +45,8 @@ func (p *PerformError) JSONResponse() util.JSONResponse {
const ( const (
// PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned) // PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned)
PerformErrorNotAllowed PerformErrorCode = 1 PerformErrorNotAllowed PerformErrorCode = 1
// PerformErrorBadRequest means the request was wrong in some way (invalid user ID, wrong server, etc)
PerformErrorBadRequest PerformErrorCode = 2
// PerformErrorNoRoom means that the room being joined doesn't exist. // PerformErrorNoRoom means that the room being joined doesn't exist.
PerformErrorNoRoom PerformErrorCode = 3 PerformErrorNoRoom PerformErrorCode = 3
// PerformErrorNoOperation means that the request resulted in nothing happening e.g invite->invite or leave->leave.
PerformErrorNoOperation PerformErrorCode = 4
// PerformErrRemote means that the request failed and the PerformError.Msg is the raw remote JSON error response
PerformErrRemote PerformErrorCode = 5
) )
type PerformJoinRequest struct { type PerformJoinRequest struct {
@ -86,14 +58,6 @@ type PerformJoinRequest struct {
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformJoinResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
JoinedVia spec.ServerName
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformLeaveRequest struct { type PerformLeaveRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -119,24 +83,6 @@ type PerformPeekRequest struct {
ServerNames []spec.ServerName `json:"server_names"` ServerNames []spec.ServerName `json:"server_names"`
} }
type PerformPeekResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
type PerformUnpeekRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
}
type PerformUnpeekResponse struct {
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
// PerformBackfillRequest is a request to PerformBackfill. // PerformBackfillRequest is a request to PerformBackfill.
type PerformBackfillRequest struct { type PerformBackfillRequest struct {
// The room to backfill // The room to backfill

View file

@ -53,32 +53,22 @@ type Joiner struct {
func (r *Joiner) PerformJoin( func (r *Joiner) PerformJoin(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse, ) (roomID string, joinedVia spec.ServerName, err error) {
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias, "room_id": req.RoomIDOrAlias,
"user_id": req.UserID, "user_id": req.UserID,
"servers": req.ServerNames, "servers": req.ServerNames,
}) })
logger.Info("User requested to room join") logger.Info("User requested to room join")
roomID, joinedVia, err := r.performJoin(context.Background(), req) roomID, joinedVia, err = r.performJoin(context.Background(), req)
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to join room") logger.WithError(err).Error("Failed to join room")
sentry.CaptureException(err) sentry.CaptureException(err)
perr, ok := err.(*rsAPI.PerformError) return "", "", err
if ok {
res.Error = perr
} else {
res.Error = &rsAPI.PerformError{
Msg: err.Error(),
}
}
return nil
} }
logger.Info("User joined room successfully") logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia return roomID, joinedVia, nil
return nil
} }
func (r *Joiner) performJoin( func (r *Joiner) performJoin(
@ -87,16 +77,10 @@ func (r *Joiner) performJoin(
) (string, spec.ServerName, error) { ) (string, spec.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performJoinRoomByID(ctx, req) return r.performJoinRoomByID(ctx, req)
@ -104,10 +88,7 @@ func (r *Joiner) performJoin(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAlias(ctx, req) return r.performJoinRoomByAlias(ctx, req)
} }
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByAlias(
@ -182,10 +163,7 @@ func (r *Joiner) performJoinRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
}
} }
// If the server name in the room ID isn't ours then it's a // If the server name in the room ID isn't ours then it's a
@ -199,10 +177,7 @@ func (r *Joiner) performJoinRoomByID(
userID := req.UserID userID := req.UserID
_, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
if err != nil { if err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", userID, err)}
Code: rsAPI.PerformErrorBadRequest,
Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err),
}
} }
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{
Type: spec.MRoomMember, Type: spec.MRoomMember,
@ -286,10 +261,7 @@ func (r *Joiner) performJoinRoomByID(
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join. // is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" { if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
} }
} }
@ -341,16 +313,10 @@ func (r *Joiner) performJoinRoomByID(
} }
inputRes := rsAPI.InputRoomEventsResponse{} inputRes := rsAPI.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNoOperation,
Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
}
} }
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{ return "", "", rsAPI.ErrNotAllowed{Err: err}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
}
} }
} }
@ -363,10 +329,7 @@ func (r *Joiner) performJoinRoomByID(
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
if len(req.ServerNames) == 0 { if len(req.ServerNames) == 0 {
return "", "", &rsAPI.PerformError{ return "", "", eventutil.ErrRoomNoExists
Code: rsAPI.PerformErrorNoRoom,
Msg: fmt.Sprintf("room ID %q does not exist", req.RoomIDOrAlias),
}
} }
} }
@ -401,11 +364,7 @@ func (r *Joiner) performFederatedJoinRoomByID(
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &rsAPI.PerformError{ return "", fedRes.LastError
Code: rsAPI.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
return fedRes.JoinedVia, nil return fedRes.JoinedVia, nil
} }
@ -429,10 +388,7 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
return "", nil return "", nil
} }
if !res.Allowed { if !res.Allowed {
return "", &rsAPI.PerformError{ return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)}
Code: rsAPI.PerformErrorNotAllowed,
Msg: fmt.Sprintf("The join to room %s was not allowed.", joinReq.RoomIDOrAlias),
}
} }
return res.AuthorisedVia, nil return res.AuthorisedVia, nil
} }

View file

@ -44,21 +44,8 @@ type Peeker struct {
func (r *Peeker) PerformPeek( func (r *Peeker) PerformPeek(
ctx context.Context, ctx context.Context,
req *api.PerformPeekRequest, req *api.PerformPeekRequest,
res *api.PerformPeekResponse, ) (roomID string, err error) {
) error { return r.performPeek(ctx, req)
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
res.RoomID = roomID
return nil
} }
func (r *Peeker) performPeek( func (r *Peeker) performPeek(
@ -68,16 +55,10 @@ func (r *Peeker) performPeek(
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
} }
if strings.HasPrefix(req.RoomIDOrAlias, "!") { if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performPeekRoomByID(ctx, req) return r.performPeekRoomByID(ctx, req)
@ -85,10 +66,7 @@ func (r *Peeker) performPeek(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performPeekRoomByAlias(ctx, req) return r.performPeekRoomByAlias(ctx, req)
} }
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
} }
func (r *Peeker) performPeekRoomByAlias( func (r *Peeker) performPeekRoomByAlias(
@ -98,7 +76,7 @@ func (r *Peeker) performPeekRoomByAlias(
// Get the domain part of the room alias. // Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias) return "", api.ErrInvalidID{Err: fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)}
} }
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, domain)
@ -147,10 +125,7 @@ func (r *Peeker) performPeekRoomByID(
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', roomID) _, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err),
}
} }
// handle federated peeks // handle federated peeks
@ -169,11 +144,7 @@ func (r *Peeker) performPeekRoomByID(
fedRes := fsAPI.PerformOutboundPeekResponse{} fedRes := fsAPI.PerformOutboundPeekResponse{}
_ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes) _ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return "", &api.PerformError{ return "", fedRes.LastError
Code: api.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
} }
} }
@ -194,17 +165,11 @@ func (r *Peeker) performPeekRoomByID(
} }
if !worldReadable { if !worldReadable {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("room is not world-readable")}
Code: api.PerformErrorNotAllowed,
Msg: "Room is not world-readable",
}
} }
if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil { if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil {
return "", &api.PerformError{ return "", api.ErrNotAllowed{Err: fmt.Errorf("Cannot peek into an encrypted room")}
Code: api.PerformErrorNotAllowed,
Msg: "Cannot peek into an encrypted room",
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks

View file

@ -34,84 +34,48 @@ type Unpeeker struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. // PerformUnpeek handles un-peeking matrix rooms, including over federation by talking to the federationapi.
func (r *Unpeeker) PerformUnpeek( func (r *Unpeeker) PerformUnpeek(
ctx context.Context, ctx context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
res *api.PerformUnpeekResponse,
) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
return nil
}
func (r *Unpeeker) performUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
) error { ) error {
// FIXME: there's way too much duplication with performJoin // FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", userID)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
} }
if strings.HasPrefix(roomID, "!") {
return r.performUnpeekRoomByID(ctx, roomID, userID, deviceID)
} }
if strings.HasPrefix(req.RoomID, "!") { return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid", roomID)}
return r.performUnpeekRoomByID(ctx, req)
}
return &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid", req.RoomID),
}
} }
func (r *Unpeeker) performUnpeekRoomByID( func (r *Unpeeker) performUnpeekRoomByID(
_ context.Context, _ context.Context,
req *api.PerformUnpeekRequest, roomID, userID, deviceID string,
) (err error) { ) (err error) {
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, _, err = gomatrixserverlib.SplitID('!', req.RoomID) _, _, err = gomatrixserverlib.SplitID('!', roomID)
if err != nil { if err != nil {
return &api.PerformError{ return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)}
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomID, err),
}
} }
// TODO: handle federated peeks // TODO: handle federated peeks
err = r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: req.RoomID,
UserID: req.UserID,
DeviceID: req.DeviceID,
},
},
})
if err != nil {
return
}
// By this point, if req.RoomIDOrAlias contained an alias, then // By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performPeekRoomByAlias. // it will have been overwritten with a room ID by performPeekRoomByAlias.
// We should now include this in the response so that the CS API can // We should now include this in the response so that the CS API can
// return the right room ID. // return the right room ID.
return nil return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{
{
Type: api.OutputTypeRetirePeek,
RetirePeek: &api.OutputRetirePeek{
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
},
},
})
} }

View file

@ -172,8 +172,8 @@ func addUserToRoom(
UserID: userID, UserID: userID,
Content: addGroupContent, Content: addGroupContent,
} }
joinRes := rsapi.PerformJoinResponse{} _, _, err := rsAPI.PerformJoin(ctx, &joinReq)
return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) return err
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {