Small tweaks to send_join to use spec roomid/userid types

This commit is contained in:
Devon Hudson 2023-05-17 09:25:21 -06:00
parent 67d6876857
commit 3ca9859bb6
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
2 changed files with 29 additions and 14 deletions

View file

@ -229,9 +229,10 @@ func SendJoin(
cfg *config.FederationAPI, cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
roomID, eventID string, roomID spec.RoomID,
eventID string,
) util.JSONResponse { ) util.JSONResponse {
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String())
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed")
return util.JSONResponse{ return util.JSONResponse{
@ -274,13 +275,13 @@ func SendJoin(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
var serverName spec.ServerName sender, err := spec.NewUserID(event.Sender(), true)
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender of the join is invalid"), JSON: spec.Forbidden("The sender of the join is invalid"),
} }
} else if serverName != request.Origin() { } else if sender.Domain() != request.Origin() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender does not match the server that originated the request"), JSON: spec.Forbidden("The sender does not match the server that originated the request"),
@ -288,7 +289,7 @@ func SendJoin(
} }
// Check that the room ID is correct. // Check that the room ID is correct.
if event.RoomID() != roomID { if event.RoomID() != roomID.String() {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON( JSON: spec.BadJSON(
@ -338,7 +339,7 @@ func SendJoin(
} }
} }
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: serverName, ServerName: sender.Domain(),
Message: redacted, Message: redacted,
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
StrictValidityChecking: true, StrictValidityChecking: true,
@ -364,7 +365,7 @@ func SendJoin(
err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{
PrevEventIDs: event.PrevEventIDs(), PrevEventIDs: event.PrevEventIDs(),
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
RoomID: roomID, RoomID: roomID.String(),
ResolveState: true, ResolveState: true,
}, &stateAndAuthChainResponse) }, &stateAndAuthChainResponse)
if err != nil { if err != nil {

View file

@ -331,14 +331,14 @@ func Setup(
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid UserID"), JSON: spec.InvalidParam("Invalid UserID"),
} }
} }
roomID, err := spec.NewRoomID(vars["roomID"]) roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid RoomID"), JSON: spec.InvalidParam("Invalid RoomID"),
} }
} }
@ -358,10 +358,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"), JSON: spec.Forbidden("Forbidden by server ACLs"),
} }
} }
roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
res := SendJoin( res := SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID, httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
) )
// not all responses get wrapped in [code, body] // not all responses get wrapped in [code, body]
var body interface{} var body interface{}
@ -390,10 +397,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"), JSON: spec.Forbidden("Forbidden by server ACLs"),
} }
} }
roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
return SendJoin( return SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID, httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
) )
}, },
)).Methods(http.MethodPut) )).Methods(http.MethodPut)