From 5f1272c74f7e6123bcaadc573b78f1c6d0ef2e54 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Sun, 21 Mar 2021 18:40:38 -0700 Subject: [PATCH] Implement OpenID module (#599) - Unrelated: change Riot references to Element in client API routing Signed-off-by: Bruce MacDonald --- clientapi/routing/openid.go | 70 +++++++++++++++ clientapi/routing/routing.go | 19 +++- federationapi/routing/openid.go | 65 ++++++++++++++ federationapi/routing/routing.go | 6 ++ setup/mscs/msc2836/msc2836_test.go | 6 ++ setup/mscs/msc2946/msc2946_test.go | 6 ++ sytest-whitelist | 3 + userapi/api/api.go | 41 +++++++++ userapi/internal/api.go | 42 +++++++++ userapi/inthttp/client.go | 18 ++++ userapi/inthttp/server.go | 26 ++++++ userapi/storage/accounts/interface.go | 2 + .../storage/accounts/postgres/openid_table.go | 84 ++++++++++++++++++ userapi/storage/accounts/postgres/storage.go | 23 +++++ .../storage/accounts/sqlite3/openid_table.go | 86 +++++++++++++++++++ userapi/storage/accounts/sqlite3/storage.go | 26 ++++++ 16 files changed, 520 insertions(+), 3 deletions(-) create mode 100644 clientapi/routing/openid.go create mode 100644 federationapi/routing/openid.go create mode 100644 userapi/storage/accounts/postgres/openid_table.go create mode 100644 userapi/storage/accounts/sqlite3/openid_table.go diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go new file mode 100644 index 000000000..12d8fa25f --- /dev/null +++ b/clientapi/routing/openid.go @@ -0,0 +1,70 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type openIDTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + MatrixServerName string `json:"matrix_server_name"` + ExpiresIn int64 `json:"expires_in"` +} + +// CreateOpenIDToken creates a new OpenID Connect (OIDC) token that a Matrix user +// can supply to an OpenID Relying Party to verify their identity +func CreateOpenIDToken( + req *http.Request, + userAPI api.UserInternalAPI, + device *api.Device, + userID string, + cfg *config.ClientAPI, +) util.JSONResponse { + // does the incoming user ID match the user that the token was issued for? + if userID != device.UserID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Cannot request tokens for other users"), + } + } + + request := api.PerformOpenIDTokenCreationRequest{ + UserID: userID, // this is the user ID from the incoming path + } + response := api.PerformOpenIDTokenCreationResponse{} + + err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed") + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: openIDTokenResponse{ + AccessToken: response.Token.Token, + TokenType: "Bearer", + MatrixServerName: string(cfg.Matrix.ServerName), + ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a56359b4c..5882fffd4 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -469,7 +469,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Stub endpoints required by Riot + // Stub endpoints required by Element r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { @@ -506,7 +506,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - // Riot user settings + // Element user settings r0mux.Handle("/profile/{userID}", httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { @@ -592,7 +592,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Riot logs get flooded unless this is handled + // Element logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { @@ -685,6 +685,19 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/user/{userId}/openid/request_token", + httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go new file mode 100644 index 000000000..55a9af0aa --- /dev/null +++ b/federationapi/routing/openid.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type openIDUserInfoResponse struct { + Sub string `json:"sub"` +} + +// GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo +func GetOpenIDUserInfo( + httpReq *http.Request, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + token := httpReq.URL.Query().Get("access_token") + if len(token) == 0 { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.MissingArgument("access_token is missing"), + } + } + + req := userapi.QueryOpenIDTokenRequest{ + Token: token, + } + + var openIDTokenAttrResponse userapi.QueryOpenIDTokenResponse + err := userAPI.QueryOpenIDToken(httpReq.Context(), &req, &openIDTokenAttrResponse) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryOpenIDToken failed") + } + + var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub} + code := http.StatusOK + nowMS := time.Now().UnixNano() / int64(time.Millisecond) + if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS { + code = http.StatusUnauthorized + res = jsonerror.UnknownToken("Access Token unknown or expired") + } + + return util.JSONResponse{ + Code: code, + JSON: res, + } +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index b579ae1fa..483cdcd00 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -460,4 +460,10 @@ func Setup( return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) + + v1fedmux.Handle("/openid/userinfo", + httputil.MakeExternalAPI("federation_openid_userinfo", func(req *http.Request) util.JSONResponse { + return GetOpenIDUserInfo(req, userAPI) + }), + ).Methods(http.MethodGet) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 4eb5708c1..79aaebc0b 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -524,6 +524,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { return nil } +func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error { + return nil +} func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } @@ -548,6 +551,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { return nil } +func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { + return nil +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 99085c0f4..96160c10d 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -367,6 +367,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { return nil } +func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error { + return nil +} func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } @@ -391,6 +394,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { return nil } +func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { + return nil +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/sytest-whitelist b/sytest-whitelist index ed02fdecb..8c4585716 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -517,3 +517,6 @@ AS can set avatar for ghosted users AS can set displayname for ghosted users Ghost user must register before joining room Inviting an AS-hosted user asks the AS server +Can generate a openid access_token that can be exchanged for information about a user +Invalid openid access tokens are rejected +Requests to userinfo without access tokens are rejected diff --git a/userapi/api/api.go b/userapi/api/api.go index 45e4e834e..0bcbc624d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -32,12 +32,14 @@ type UserInternalAPI interface { PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error + PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error + QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -226,6 +228,27 @@ type PerformAccountDeactivationResponse struct { AccountDeactivated bool } +// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationRequest struct { + UserID string +} + +// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationResponse struct { + Token OpenIDToken +} + +// QueryOpenIDTokenRequest is the request for QueryOpenIDToken +type QueryOpenIDTokenRequest struct { + Token string +} + +// QueryOpenIDTokenResponse is the response for QueryOpenIDToken +type QueryOpenIDTokenResponse struct { + Sub string // The Matrix User ID that generated the token + ExpiresTS int64 +} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string @@ -256,6 +279,24 @@ type Account struct { // TODO: Associations (e.g. with application services) } +// OpenIDToken represents an OpenID token +type OpenIDToken struct { + Token string + UserID string + ExpiresTS int64 +} + +// OpenIDTokenInfo represents the attributes associated with an issued OpenID token +type OpenIDTokenAttributes struct { + UserID string + ExpiresTS int64 +} + +// UserInfo is for returning information about the user an OpenID token was issued for +type UserInfo struct { + Sub string // The Matrix user's ID who generated the token +} + // ErrorForbidden is an error indicating that the supplied access token is forbidden type ErrorForbidden struct { Message string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 0d01afa19..3dbbf3a2c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,9 @@ import ( "encoding/json" "errors" "fmt" + "hash" + "hash/fnv" + "time" "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" @@ -414,3 +417,42 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a res.AccountDeactivated = err == nil return err } + +// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user +func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { + token := util.RandomString(24) + tokenHash := getTokenHash(token) + nowMS := time.Now().UnixNano() / int64(time.Millisecond) + expiresMS := nowMS + (3600 * 1000) // 60 minutes + + err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS) + + res.Token = api.OpenIDToken{ + Token: token, + UserID: req.UserID, + ExpiresTS: expiresMS, + } + + return err +} + +// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation +func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + tokenHash := getTokenHash(req.Token) + openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32())) + if err != nil { + return err + } + + res.Sub = openIDTokenAttrs.UserID + res.ExpiresTS = openIDTokenAttrs.ExpiresTS + + return nil +} + +// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info +func getTokenHash(token string) hash.Hash32 { + tokenHash := fnv.New32a() + _, _ = tokenHash.Write([]byte(token)) + return tokenHash +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 680e4cb52..1cb5ef0a8 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -35,6 +35,7 @@ const ( PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" + PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -42,6 +43,7 @@ const ( QueryAccountDataPath = "/userapi/queryAccountData" QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -148,6 +150,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } +func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformOpenIDTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, @@ -207,3 +217,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api. apiURL := h.apiURL + QuerySearchProfilesPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") + defer span.Finish() + + apiURL := h.apiURL + QueryOpenIDTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index e495e3536..1c1cfdcd1 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -117,6 +117,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformOpenIDTokenCreationPath, + httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformOpenIDTokenCreationRequest{} + response := api.PerformOpenIDTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryProfilePath, httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { request := api.QueryProfileRequest{} @@ -195,6 +208,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOpenIDTokenPath, + httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse { + request := api.QueryOpenIDTokenRequest{} + response := api.QueryOpenIDTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(InputAccountDataPath, httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { request := api.InputAccountDataRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index c86b2c391..ed2e6696c 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -52,6 +52,8 @@ type Database interface { GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) DeactivateAccount(ctx context.Context, localpart string) (err error) + CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error) + GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go new file mode 100644 index 000000000..aa47fc420 --- /dev/null +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -0,0 +1,84 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about openid tokens issued for accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- This is a hash the token value + token_hash TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_ts BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1" + +type tokenStatements struct { + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if token hash already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + tokenHash, localpart string, + expiresTimeMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS) + return +} + +// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash( + ctx context.Context, + tokenHash string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresTS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 3933fe5bd..c84afbe1f 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -43,6 +43,7 @@ type Database struct { profiles profilesStatements accountDatas accountDataStatements threepids threepidStatements + openIDTokens tokenStatements serverName gomatrixserverlib.ServerName bcryptCost int } @@ -86,6 +87,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -341,3 +345,22 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued through OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + tokenHash, localpart string, + expirationTS int64, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS) + }) +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + tokenHash string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash) +} diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go new file mode 100644 index 000000000..47602bcbd --- /dev/null +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -0,0 +1,86 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- This is the hash of the token value + token_hash TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_ts BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1" + +type tokenStatements struct { + db *sql.DB + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return err + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if token hash already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token_hash, localpart string, + expiresTimeMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS) + return +} + +// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash( + ctx context.Context, + tokenHash string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresTS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 07cc68b35..2149850c9 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -41,6 +41,7 @@ type Database struct { profiles profilesStatements accountDatas accountDataStatements threepids threepidStatements + openIDTokens tokenStatements serverName gomatrixserverlib.ServerName bcryptCost int @@ -48,6 +49,7 @@ type Database struct { profilesMu sync.Mutex accountDatasMu sync.Mutex threepidsMu sync.Mutex + openIDsMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -90,6 +92,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -379,3 +384,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + tokenHash, localpart string, + expirationTS int64, +) error { + d.openIDsMu.Lock() + defer d.openIDsMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS) + }) +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + tokenHash string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash) +}