diff --git a/clientapi/auth/sso/oauth2.go b/clientapi/auth/sso/oauth2.go index 4d62ce3a6..cd6186938 100644 --- a/clientapi/auth/sso/oauth2.go +++ b/clientapi/auth/sso/oauth2.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -130,29 +131,17 @@ func (p *oauth2IdentityProvider) getAccessToken(ctx context.Context, callbackURL hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") hreq.Header.Set("Accept", p.responseMimeType) - hresp, err := p.hc.Do(hreq) + hresp, err := httpDo(ctx, p.hc, hreq) if err != nil { - return "", err + return "", fmt.Errorf("access token: %w", err) } defer hresp.Body.Close() // nolint:errcheck - if hresp.StatusCode/100 != 2 { - return "", fmt.Errorf("OAuth2 access token request %q failed: %d %s", p.accessTokenURL, hresp.StatusCode, hresp.Status) - } - var resp oauth2TokenResponse if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { return "", err } - if resp.Error != "" { - desc := resp.ErrorDescription - if desc == "" { - desc = resp.Error - } - return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) - } - if strings.ToLower(resp.TokenType) != "bearer" { return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) } @@ -163,10 +152,6 @@ func (p *oauth2IdentityProvider) getAccessToken(ctx context.Context, callbackURL type oauth2TokenResponse struct { TokenType string `json:"token_type"` AccessToken string `json:"access_token"` - - Error string `json:"error"` - ErrorDescription string `json:"error_description"` - ErrorURI string `json:"error_uri"` } func (p *oauth2IdentityProvider) getUserInfo(ctx context.Context, accessToken string) (subject, displayName, suggestedLocalpart string, _ error) { @@ -177,16 +162,12 @@ func (p *oauth2IdentityProvider) getUserInfo(ctx context.Context, accessToken st hreq.Header.Set("Authorization", "Bearer "+accessToken) hreq.Header.Set("Accept", p.responseMimeType) - hresp, err := p.hc.Do(hreq) + hresp, err := httpDo(ctx, p.hc, hreq) if err != nil { - return "", "", "", err + return "", "", "", fmt.Errorf("user info: %w", err) } defer hresp.Body.Close() // nolint:errcheck - if hresp.StatusCode/100 != 2 { - return "", "", "", fmt.Errorf("OAuth2 user info request %q failed: %d %s", p.userInfoURL, hresp.StatusCode, hresp.Status) - } - body, err := ioutil.ReadAll(hresp.Body) if err != nil { return "", "", "", err @@ -212,6 +193,56 @@ func (p *oauth2IdentityProvider) getUserInfo(ctx context.Context, accessToken st return } +func httpDo(ctx context.Context, hc *http.Client, req *http.Request) (*http.Response, error) { + resp, err := hc.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode/100 != 2 { + defer resp.Body.Close() + + contentType := resp.Header.Get("Content-Type") + switch { + case strings.HasPrefix(contentType, "text/plain"): + bs, err := ioutil.ReadAll(resp.Body) + if err == nil { + if len(bs) > 80 { + bs = bs[:80] + } + util.GetLogger(ctx).WithField("url", req.URL.String()).WithField("status", resp.StatusCode).Warnf("OAuth2 HTTP request failed: %s", string(bs)) + } + case strings.HasPrefix(contentType, "application/json"): + // https://openid.net/specs/openid-connect-core-1_0.html#TokenErrorResponse + var body oauth2Error + if err := json.NewDecoder(resp.Body).Decode(&body); err == nil { + util.GetLogger(ctx).WithField("url", req.URL.String()).WithField("status", resp.StatusCode).Warnf("OAuth2 HTTP request failed: %+v", &body) + } + if body.Error != "" { + return nil, fmt.Errorf("OAuth2 request %q failed: %s (%s)", req.URL.String(), resp.Status, body.Error) + } + } + + if hdr := resp.Header.Get("WWW-Authenticate"); hdr != "" { + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError + if len(hdr) > 80 { + hdr = hdr[:80] + } + return nil, fmt.Errorf("OAuth2 request %q failed: %s (%s)", req.URL.String(), resp.Status, hdr) + } + + return nil, fmt.Errorf("OAuth2 HTTP request %q failed: %s", req.URL.String(), resp.Status) + } + + return resp, nil +} + +type oauth2Error struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + func resolveURL(urlString string, defaultQuery url.Values) (*url.URL, error) { u, err := url.Parse(urlString) if err != nil {