Implement OpenID module (#599)

- Unrelated: change Riot references to Element in client API routing

Signed-off-by: Bruce MacDonald <contact@bruce-macdonald.com>
This commit is contained in:
Bruce MacDonald 2021-03-21 18:40:38 -07:00
parent 01267a34b9
commit 5f1272c74f
16 changed files with 520 additions and 3 deletions

View file

@ -0,0 +1,70 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type openIDTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
MatrixServerName string `json:"matrix_server_name"`
ExpiresIn int64 `json:"expires_in"`
}
// CreateOpenIDToken creates a new OpenID Connect (OIDC) token that a Matrix user
// can supply to an OpenID Relying Party to verify their identity
func CreateOpenIDToken(
req *http.Request,
userAPI api.UserInternalAPI,
device *api.Device,
userID string,
cfg *config.ClientAPI,
) util.JSONResponse {
// does the incoming user ID match the user that the token was issued for?
if userID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot request tokens for other users"),
}
}
request := api.PerformOpenIDTokenCreationRequest{
UserID: userID, // this is the user ID from the incoming path
}
response := api.PerformOpenIDTokenCreationResponse{}
err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: openIDTokenResponse{
AccessToken: response.Token.Token,
TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
},
}
}

View file

@ -469,7 +469,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot // Stub endpoints required by Element
r0mux.Handle("/login", r0mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
@ -506,7 +506,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// Riot user settings // Element user settings
r0mux.Handle("/profile/{userID}", r0mux.Handle("/profile/{userID}",
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
@ -592,7 +592,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// Riot logs get flooded unless this is handled // Element logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status", r0mux.Handle("/presence/{userID}/status",
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil { if r := rateLimits.rateLimit(req); r != nil {
@ -685,6 +685,19 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user_directory/search", r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil { if r := rateLimits.rateLimit(req); r != nil {

View file

@ -0,0 +1,65 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type openIDUserInfoResponse struct {
Sub string `json:"sub"`
}
// GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo
func GetOpenIDUserInfo(
httpReq *http.Request,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
token := httpReq.URL.Query().Get("access_token")
if len(token) == 0 {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.MissingArgument("access_token is missing"),
}
}
req := userapi.QueryOpenIDTokenRequest{
Token: token,
}
var openIDTokenAttrResponse userapi.QueryOpenIDTokenResponse
err := userAPI.QueryOpenIDToken(httpReq.Context(), &req, &openIDTokenAttrResponse)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryOpenIDToken failed")
}
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
code := http.StatusOK
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS {
code = http.StatusUnauthorized
res = jsonerror.UnknownToken("Access Token unknown or expired")
}
return util.JSONResponse{
Code: code,
JSON: res,
}
}

View file

@ -460,4 +460,10 @@ func Setup(
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
}, },
)).Methods(http.MethodPost) )).Methods(http.MethodPost)
v1fedmux.Handle("/openid/userinfo",
httputil.MakeExternalAPI("federation_openid_userinfo", func(req *http.Request) util.JSONResponse {
return GetOpenIDUserInfo(req, userAPI)
}),
).Methods(http.MethodGet)
} }

View file

@ -524,6 +524,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
return nil return nil
} }
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
return nil
}
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil return nil
} }
@ -548,6 +551,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
return nil return nil
} }
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
return nil
}
type testRoomserverAPI struct { type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here. // use a trace API as it implements method stubs so we don't need to have them here.

View file

@ -367,6 +367,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
return nil return nil
} }
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
return nil
}
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil return nil
} }
@ -391,6 +394,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
return nil return nil
} }
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
return nil
}
type testRoomserverAPI struct { type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here. // use a trace API as it implements method stubs so we don't need to have them here.

View file

@ -517,3 +517,6 @@ AS can set avatar for ghosted users
AS can set displayname for ghosted users AS can set displayname for ghosted users
Ghost user must register before joining room Ghost user must register before joining room
Inviting an AS-hosted user asks the AS server Inviting an AS-hosted user asks the AS server
Can generate a openid access_token that can be exchanged for information about a user
Invalid openid access tokens are rejected
Requests to userinfo without access tokens are rejected

View file

@ -32,12 +32,14 @@ type UserInternalAPI interface {
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
@ -226,6 +228,27 @@ type PerformAccountDeactivationResponse struct {
AccountDeactivated bool AccountDeactivated bool
} }
// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationRequest struct {
UserID string
}
// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation
type PerformOpenIDTokenCreationResponse struct {
Token OpenIDToken
}
// QueryOpenIDTokenRequest is the request for QueryOpenIDToken
type QueryOpenIDTokenRequest struct {
Token string
}
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct {
Sub string // The Matrix User ID that generated the token
ExpiresTS int64
}
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string
@ -256,6 +279,24 @@ type Account struct {
// TODO: Associations (e.g. with application services) // TODO: Associations (e.g. with application services)
} }
// OpenIDToken represents an OpenID token
type OpenIDToken struct {
Token string
UserID string
ExpiresTS int64
}
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
type OpenIDTokenAttributes struct {
UserID string
ExpiresTS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for
type UserInfo struct {
Sub string // The Matrix user's ID who generated the token
}
// ErrorForbidden is an error indicating that the supplied access token is forbidden // ErrorForbidden is an error indicating that the supplied access token is forbidden
type ErrorForbidden struct { type ErrorForbidden struct {
Message string Message string

View file

@ -20,6 +20,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"hash"
"hash/fnv"
"time"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -414,3 +417,42 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
tokenHash := getTokenHash(token)
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := nowMS + (3600 * 1000) // 60 minutes
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS)
res.Token = api.OpenIDToken{
Token: token,
UserID: req.UserID,
ExpiresTS: expiresMS,
}
return err
}
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
tokenHash := getTokenHash(req.Token)
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
if err != nil {
return err
}
res.Sub = openIDTokenAttrs.UserID
res.ExpiresTS = openIDTokenAttrs.ExpiresTS
return nil
}
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
func getTokenHash(token string) hash.Hash32 {
tokenHash := fnv.New32a()
_, _ = tokenHash.Write([]byte(token))
return tokenHash
}

View file

@ -35,6 +35,7 @@ const (
PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
@ -42,6 +43,7 @@ const (
QueryAccountDataPath = "/userapi/queryAccountData" QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles" QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -148,6 +150,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
defer span.Finish()
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryProfile( func (h *httpUserInternalAPI) QueryProfile(
ctx context.Context, ctx context.Context,
request *api.QueryProfileRequest, request *api.QueryProfileRequest,
@ -207,3 +217,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.
apiURL := h.apiURL + QuerySearchProfilesPath apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
defer span.Finish()
apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -117,6 +117,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformOpenIDTokenCreationRequest{}
response := api.PerformOpenIDTokenCreationResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath, internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{} request := api.QueryProfileRequest{}
@ -195,6 +208,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryOpenIDTokenPath,
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
request := api.QueryOpenIDTokenRequest{}
response := api.QueryOpenIDTokenResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(InputAccountDataPath, internalAPIMux.Handle(InputAccountDataPath,
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
request := api.InputAccountDataRequest{} request := api.InputAccountDataRequest{}

View file

@ -52,6 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error)
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,84 @@
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is a hash the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
type tokenStatements struct {
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
tokenHash, localpart string,
expiresTimeMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
}
return nil, err
}
return &openIDTokenAttrs, nil
}

View file

@ -43,6 +43,7 @@ type Database struct {
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
} }
@ -86,6 +87,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil { if err = d.threepids.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -341,3 +345,22 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart) return d.accounts.deactivateAccount(ctx, localpart)
} }
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
})
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
}

View file

@ -0,0 +1,86 @@
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is the hash of the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
type tokenStatements struct {
db *sql.DB
insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(openIDTokenSchema)
if err != nil {
return err
}
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
return
}
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
return
}
s.serverName = server
return
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token_hash, localpart string,
expiresTimeMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
}
return nil, err
}
return &openIDTokenAttrs, nil
}

View file

@ -41,6 +41,7 @@ type Database struct {
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
@ -48,6 +49,7 @@ type Database struct {
profilesMu sync.Mutex profilesMu sync.Mutex
accountDatasMu sync.Mutex accountDatasMu sync.Mutex
threepidsMu sync.Mutex threepidsMu sync.Mutex
openIDsMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -90,6 +92,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.threepids.prepare(db); err != nil { if err = d.threepids.prepare(db); err != nil {
return nil, err return nil, err
} }
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -379,3 +384,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart) return d.accounts.deactivateAccount(ctx, localpart)
} }
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
d.openIDsMu.Lock()
defer d.openIDsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
})
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
}