Small tweaks to send_join to use spec roomid/userid types

This commit is contained in:
Devon Hudson 2023-05-17 09:25:21 -06:00
parent 67d6876857
commit 3ca9859bb6
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
2 changed files with 29 additions and 14 deletions

View file

@ -229,9 +229,10 @@ func SendJoin(
cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI,
keys gomatrixserverlib.JSONVerifier,
roomID, eventID string,
roomID spec.RoomID,
eventID string,
) util.JSONResponse {
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String())
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed")
return util.JSONResponse{
@ -274,13 +275,13 @@ func SendJoin(
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
var serverName spec.ServerName
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
sender, err := spec.NewUserID(event.Sender(), true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender of the join is invalid"),
}
} else if serverName != request.Origin() {
} else if sender.Domain() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender does not match the server that originated the request"),
@ -288,7 +289,7 @@ func SendJoin(
}
// Check that the room ID is correct.
if event.RoomID() != roomID {
if event.RoomID() != roomID.String() {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(
@ -338,7 +339,7 @@ func SendJoin(
}
}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: serverName,
ServerName: sender.Domain(),
Message: redacted,
AtTS: event.OriginServerTS(),
StrictValidityChecking: true,
@ -364,7 +365,7 @@ func SendJoin(
err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{
PrevEventIDs: event.PrevEventIDs(),
AuthEventIDs: event.AuthEventIDs(),
RoomID: roomID,
RoomID: roomID.String(),
ResolveState: true,
}, &stateAndAuthChainResponse)
if err != nil {

View file

@ -331,14 +331,14 @@ func Setup(
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid UserID"),
JSON: spec.InvalidParam("Invalid UserID"),
}
}
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid RoomID"),
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
@ -358,10 +358,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"]
eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
res := SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID,
httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
// not all responses get wrapped in [code, body]
var body interface{}
@ -390,10 +397,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"]
eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
return SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID,
httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
},
)).Methods(http.MethodPut)