From 3ca9859bb6e588641956ea55ab32457ea505cb51 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 17 May 2023 09:25:21 -0600 Subject: [PATCH] Small tweaks to send_join to use spec roomid/userid types --- federationapi/routing/join.go | 17 +++++++++-------- federationapi/routing/routing.go | 26 ++++++++++++++++++++------ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index cc22690a9..d44491a76 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -229,9 +229,10 @@ func SendJoin( cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, - roomID, eventID string, + roomID spec.RoomID, + eventID string, ) util.JSONResponse { - roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ @@ -274,13 +275,13 @@ func SendJoin( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + sender, err := spec.NewUserID(event.Sender(), true) + if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } else if sender.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), @@ -288,7 +289,7 @@ func SendJoin( } // Check that the room ID is correct. - if event.RoomID() != roomID { + if event.RoomID() != roomID.String() { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON( @@ -338,7 +339,7 @@ func SendJoin( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, + ServerName: sender.Domain(), Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, @@ -364,7 +365,7 @@ func SendJoin( err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ PrevEventIDs: event.PrevEventIDs(), AuthEventIDs: event.AuthEventIDs(), - RoomID: roomID, + RoomID: roomID.String(), ResolveState: true, }, &stateAndAuthChainResponse) if err != nil { diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 44faad918..7be0857a6 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -331,14 +331,14 @@ func Setup( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid UserID"), + JSON: spec.InvalidParam("Invalid UserID"), } } roomID, err := spec.NewRoomID(vars["roomID"]) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid RoomID"), + JSON: spec.InvalidParam("Invalid RoomID"), } } @@ -358,10 +358,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + res := SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) // not all responses get wrapped in [code, body] var body interface{} @@ -390,10 +397,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + return SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) }, )).Methods(http.MethodPut)