diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 726aaa03f..812d3c6da 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -158,7 +158,7 @@ type getUserDevices struct { func (h *httpFederationInternalAPI) GetUserDevices( ctx context.Context, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { - return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, api.FederationClientError]( + return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, ctx, &getUserDevices{ S: s, @@ -175,7 +175,7 @@ type claimKeys struct { func (h *httpFederationInternalAPI) ClaimKeys( ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { - return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, api.FederationClientError]( + return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, ctx, &claimKeys{ S: s, @@ -192,7 +192,7 @@ type queryKeys struct { func (h *httpFederationInternalAPI) QueryKeys( ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { - return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, api.FederationClientError]( + return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, ctx, &queryKeys{ S: s, @@ -211,7 +211,7 @@ type backfill struct { func (h *httpFederationInternalAPI) Backfill( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, api.FederationClientError]( + return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, ctx, &backfill{ S: s, @@ -232,7 +232,7 @@ type lookupState struct { func (h *httpFederationInternalAPI) LookupState( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { - return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, api.FederationClientError]( + return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, ctx, &lookupState{ S: s, @@ -252,7 +252,7 @@ type lookupStateIDs struct { func (h *httpFederationInternalAPI) LookupStateIDs( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { - return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, api.FederationClientError]( + return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, ctx, &lookupStateIDs{ S: s, @@ -273,7 +273,7 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { - return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, api.FederationClientError]( + return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, ctx, &lookupMissingEvents{ S: s, @@ -292,7 +292,7 @@ type getEvent struct { func (h *httpFederationInternalAPI) GetEvent( ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { - return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, api.FederationClientError]( + return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, ctx, &getEvent{ S: s, @@ -312,7 +312,7 @@ func (h *httpFederationInternalAPI) GetEventAuth( ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { - return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, api.FederationClientError]( + return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, ctx, &getEventAuth{ S: s, @@ -340,7 +340,7 @@ type lookupServerKeys struct { func (h *httpFederationInternalAPI) LookupServerKeys( ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { - return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, api.FederationClientError]( + return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError]( "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient, ctx, &lookupServerKeys{ S: s, @@ -359,7 +359,7 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, api.FederationClientError]( + return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, ctx, &eventRelationships{ S: s, @@ -378,7 +378,7 @@ type spacesReq struct { func (h *httpFederationInternalAPI) MSC2946Spaces( ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, api.FederationClientError]( + return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, ctx, &spacesReq{ S: dst, diff --git a/internal/httputil/http.go b/internal/httputil/http.go index 0788ac69d..a700c9684 100644 --- a/internal/httputil/http.go +++ b/internal/httputil/http.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" @@ -30,9 +31,9 @@ import ( // PostJSON performs a POST request with JSON on an internal HTTP API. // The error will match the errtype if returned from the remote API, or // will be a different type if there was a problem reaching the API. -func PostJSON[reqtype, restype, errtype any]( +func PostJSON[reqtype, restype any, errtype error]( ctx context.Context, span opentracing.Span, httpClient *http.Client, - apiURL string, request *reqtype, response *restype, reserr *errtype, + apiURL string, request *reqtype, response *restype, ) error { jsonBytes, err := json.Marshal(request) if err != nil { @@ -70,14 +71,23 @@ func PostJSON[reqtype, restype, errtype any]( if err != nil { return err } + var body []byte + body, err = io.ReadAll(res.Body) + if err != nil { + return err + } if res.StatusCode != http.StatusOK { - if err = json.NewDecoder(res.Body).Decode(reserr); err != nil { + if len(body) == 0 { return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL) } - return nil + var reserr errtype + if err = json.Unmarshal(body, reserr); err != nil { + return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL) + } + return reserr } - if err = json.NewDecoder(res.Body).Decode(response); err != nil { - return fmt.Errorf("json.NewDecoder.Decode: %w", err) + if err = json.Unmarshal(body, response); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) } return nil } diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go index a366631f8..385092d9c 100644 --- a/internal/httputil/internalapi.go +++ b/internal/httputil/internalapi.go @@ -81,25 +81,13 @@ func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Cli span, ctx := opentracing.StartSpanFromContext(ctx, name) defer span.Finish() - var reserr *InternalAPIError - if err := PostJSON(ctx, span, client, url, request, response, reserr); err != nil { - return err - } else if reserr != nil { - return reserr - } - return nil // must be untyped nil + return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response) } -func CallInternalProxyAPI[req, res any, errtype error](name, url string, client *http.Client, ctx context.Context, request *req) (res, error) { +func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) { span, ctx := opentracing.StartSpanFromContext(ctx, name) defer span.Finish() - var response res - var reserr *errtype - if err := PostJSON(ctx, span, client, url, request, &response, &reserr); err != nil { - return response, err - } else if reserr != nil { - return response, *reserr - } - return response, nil // must be untyped nil + var response restype + return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response) }