Add JoinedVia to PerformJoin responses

This commit is contained in:
Kegan Dougal 2020-11-18 18:09:05 +00:00
parent 46cbbc3fc9
commit 99b3d56a73
5 changed files with 43 additions and 30 deletions

View file

@ -105,6 +105,7 @@ type PerformJoinRequest struct {
} }
type PerformJoinResponse struct { type PerformJoinResponse struct {
JoinedVia gomatrixserverlib.ServerName
LastError *gomatrix.HTTPError LastError *gomatrix.HTTPError
} }

View file

@ -105,6 +105,7 @@ func (r *FederationSenderInternalAPI) PerformJoin(
} }
// We're all good. // We're all good.
response.JoinedVia = serverName
return return
} }

View file

@ -101,6 +101,7 @@ type EventRelationshipResponse struct {
} }
// Enable this MSC // Enable this MSC
// nolint:gocyclo
func Enable( func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
@ -170,7 +171,7 @@ func Enable(
}) })
base.PublicClientAPIMux.Handle("/unstable/event_relationships", base.PublicClientAPIMux.Handle("/unstable/event_relationships",
httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)), httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
@ -191,13 +192,12 @@ type reqCtx struct {
ctx context.Context ctx context.Context
rsAPI roomserver.RoomserverInternalAPI rsAPI roomserver.RoomserverInternalAPI
db Database db Database
fsAPI fs.FederationSenderInternalAPI
req *EventRelationshipRequest req *EventRelationshipRequest
userID string userID string
isFederatedRequest bool isFederatedRequest bool
} }
func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse { return func(req *http.Request, device *userapi.Device) util.JSONResponse {
relation, err := NewEventRelationshipRequest(req.Body) relation, err := NewEventRelationshipRequest(req.Body)
if err != nil { if err != nil {
@ -332,7 +332,8 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
} }
var childEvents []*gomatrixserverlib.HeaderedEvent var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children { for _, child := range children {
childEvent := rc.getEventIfVisible(child.EventID, child.RoomID, child.Servers) // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers.
childEvent := rc.getEventIfVisible(child.EventID, "", nil)
if childEvent != nil { if childEvent != nil {
childEvents = append(childEvents, childEvent) childEvents = append(childEvents, childEvent)
} }
@ -370,6 +371,7 @@ func walkThread(
} }
// Process the event. // Process the event.
// TODO: Include edge information: room ID and servers
event := rc.getEventIfVisible(wi.EventID, "", nil) event := rc.getEventIfVisible(wi.EventID, "", nil)
if event != nil { if event != nil {
result = append(result, event) result = append(result, event)
@ -394,6 +396,10 @@ func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claime
if !rc.req.AutoJoin { if !rc.req.AutoJoin {
return nil return nil
} }
// if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says
if rc.isFederatedRequest {
return nil
}
roomID := claimedRoomID roomID := claimedRoomID
var servers []gomatrixserverlib.ServerName var servers []gomatrixserverlib.ServerName
if event != nil { if event != nil {
@ -416,7 +422,7 @@ func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claime
if event != nil { if event != nil {
return event return event
} }
// TODO: fetch the event in question // TODO: hit /event_relationships on the server we joined via
util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO") util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO")
return nil return nil
} }

View file

@ -84,6 +84,7 @@ type PerformJoinRequest struct {
type PerformJoinResponse struct { type PerformJoinResponse struct {
// The room ID, populated on success. // The room ID, populated on success.
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
JoinedVia gomatrixserverlib.ServerName
// If non-nil, the join request failed. Contains more information why it failed. // If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError Error *PerformError
} }

View file

@ -47,7 +47,7 @@ func (r *Joiner) PerformJoin(
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
res *api.PerformJoinResponse, res *api.PerformJoinResponse,
) { ) {
roomID, err := r.performJoin(ctx, req) roomID, joinedVia, err := r.performJoin(ctx, req)
if err != nil { if err != nil {
perr, ok := err.(*api.PerformError) perr, ok := err.(*api.PerformError)
if ok { if ok {
@ -59,21 +59,22 @@ func (r *Joiner) PerformJoin(
} }
} }
res.RoomID = roomID res.RoomID = roomID
res.JoinedVia = joinedVia
} }
func (r *Joiner) performJoin( func (r *Joiner) performJoin(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
) (string, error) { ) (string, gomatrixserverlib.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID) _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
} }
} }
if domain != r.Cfg.Matrix.ServerName { if domain != r.Cfg.Matrix.ServerName {
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
} }
@ -84,7 +85,7 @@ func (r *Joiner) performJoin(
if strings.HasPrefix(req.RoomIDOrAlias, "#") { if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAlias(ctx, req) return r.performJoinRoomByAlias(ctx, req)
} }
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
} }
@ -93,11 +94,11 @@ func (r *Joiner) performJoin(
func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByAlias(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
) (string, error) { ) (string, gomatrixserverlib.ServerName, error) {
// Get the domain part of the room alias. // Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) return "", "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
} }
req.ServerNames = append(req.ServerNames, domain) req.ServerNames = append(req.ServerNames, domain)
@ -115,7 +116,7 @@ func (r *Joiner) performJoinRoomByAlias(
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) return "", "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
} }
roomID = dirRes.RoomID roomID = dirRes.RoomID
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...) req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
@ -123,13 +124,13 @@ func (r *Joiner) performJoinRoomByAlias(
// Otherwise, look up if we know this room alias locally. // Otherwise, look up if we know this room alias locally.
roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) return "", "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
} }
} }
// If the room ID is empty then we failed to look up the alias. // If the room ID is empty then we failed to look up the alias.
if roomID == "" { if roomID == "" {
return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) return "", "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
} }
// If we do, then pluck out the room ID and continue the join. // If we do, then pluck out the room ID and continue the join.
@ -142,11 +143,11 @@ func (r *Joiner) performJoinRoomByAlias(
func (r *Joiner) performJoinRoomByID( func (r *Joiner) performJoinRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
) (string, error) { ) (string, gomatrixserverlib.ServerName, error) {
// Get the domain part of the room ID. // Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
if err != nil { if err != nil {
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest, Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err), Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
} }
@ -169,7 +170,7 @@ func (r *Joiner) performJoinRoomByID(
Redacts: "", Redacts: "",
} }
if err = eb.SetUnsigned(struct{}{}); err != nil { if err = eb.SetUnsigned(struct{}{}); err != nil {
return "", fmt.Errorf("eb.SetUnsigned: %w", err) return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// It is possible for the request to include some "content" for the // It is possible for the request to include some "content" for the
@ -180,7 +181,7 @@ func (r *Joiner) performJoinRoomByID(
} }
req.Content["membership"] = gomatrixserverlib.Join req.Content["membership"] = gomatrixserverlib.Join
if err = eb.SetContent(req.Content); err != nil { if err = eb.SetContent(req.Content); err != nil {
return "", fmt.Errorf("eb.SetContent: %w", err) return "", "", fmt.Errorf("eb.SetContent: %w", err)
} }
// Force a federated join if we aren't in the room and we've been // Force a federated join if we aren't in the room and we've been
@ -194,7 +195,7 @@ func (r *Joiner) performJoinRoomByID(
if err == nil && isInvitePending { if err == nil && isInvitePending {
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
if ierr != nil { if ierr != nil {
return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
// If we were invited by someone from another server then we can // If we were invited by someone from another server then we can
@ -206,8 +207,10 @@ func (r *Joiner) performJoinRoomByID(
} }
// If we should do a forced federated join then do that. // If we should do a forced federated join then do that.
var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin { if forceFederatedJoin {
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
return req.RoomIDOrAlias, joinedVia, err
} }
// Try to construct an actual join event from the template. // Try to construct an actual join event from the template.
@ -249,7 +252,7 @@ func (r *Joiner) performJoinRoomByID(
inputRes := api.InputRoomEventsResponse{} inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
} }
@ -265,7 +268,7 @@ func (r *Joiner) performJoinRoomByID(
// Otherwise we'll try a federated join as normal, since it's quite // Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers. // possible that the room still exists on other servers.
if len(req.ServerNames) == 0 { if len(req.ServerNames) == 0 {
return "", &api.PerformError{ return "", "", &api.PerformError{
Code: api.PerformErrorNoRoom, Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias), Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias),
} }
@ -273,24 +276,25 @@ func (r *Joiner) performJoinRoomByID(
} }
// Perform a federated room join. // Perform a federated room join.
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
return req.RoomIDOrAlias, joinedVia, err
default: default:
// Something else went wrong. // Something else went wrong.
return "", fmt.Errorf("Error joining local room: %q", err) return "", "", fmt.Errorf("Error joining local room: %q", err)
} }
// By this point, if req.RoomIDOrAlias contained an alias, then // By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performJoinRoomByAlias. // it will have been overwritten with a room ID by performJoinRoomByAlias.
// We should now include this in the response so that the CS API can // We should now include this in the response so that the CS API can
// return the right room ID. // return the right room ID.
return req.RoomIDOrAlias, nil return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil
} }
func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) performFederatedJoinRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
) error { ) (gomatrixserverlib.ServerName, error) {
// Try joining by all of the supplied server names. // Try joining by all of the supplied server names.
fedReq := fsAPI.PerformJoinRequest{ fedReq := fsAPI.PerformJoinRequest{
RoomID: req.RoomIDOrAlias, // the room ID to try and join RoomID: req.RoomIDOrAlias, // the room ID to try and join
@ -301,13 +305,13 @@ func (r *Joiner) performFederatedJoinRoomByID(
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil { if fedRes.LastError != nil {
return &api.PerformError{ return "", &api.PerformError{
Code: api.PerformErrRemote, Code: api.PerformErrRemote,
Msg: fedRes.LastError.Message, Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code, RemoteCode: fedRes.LastError.Code,
} }
} }
return nil return fedRes.JoinedVia, nil
} }
func buildEvent( func buildEvent(