From d27607af78a53bda636f14f603b02b2952d6e1d8 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 7 Apr 2021 05:26:20 -0700 Subject: [PATCH] Implement OpenID module (#599) (#1812) * Implement OpenID module (#599) - Unrelated: change Riot references to Element in client API routing Signed-off-by: Bruce MacDonald * OpenID module tweaks (#599) - specify expiry is ms rather than vague ts - add OpenID token lifetime to configuration - use Go naming conventions for the path params - store plaintext token rather than hash - remove openid table sqllite mutex * Add default OpenID token lifetime (#599) * Update dendrite-config.yaml Co-authored-by: Kegsay Co-authored-by: Kegsay --- clientapi/routing/openid.go | 70 +++++++++++++++ clientapi/routing/routing.go | 19 +++- cmd/create-account/main.go | 2 +- dendrite-config.yaml | 5 ++ federationapi/routing/openid.go | 65 ++++++++++++++ federationapi/routing/routing.go | 6 ++ setup/base.go | 2 +- setup/config/config_userapi.go | 7 ++ setup/mscs/msc2836/msc2836_test.go | 6 ++ setup/mscs/msc2946/msc2946_test.go | 6 ++ sytest-whitelist | 3 + userapi/api/api.go | 41 +++++++++ userapi/internal/api.go | 28 ++++++ userapi/inthttp/client.go | 18 ++++ userapi/inthttp/server.go | 26 ++++++ userapi/storage/accounts/interface.go | 2 + .../storage/accounts/postgres/openid_table.go | 84 ++++++++++++++++++ userapi/storage/accounts/postgres/storage.go | 49 ++++++++--- .../storage/accounts/sqlite3/openid_table.go | 86 +++++++++++++++++++ userapi/storage/accounts/sqlite3/storage.go | 49 ++++++++--- userapi/storage/accounts/storage.go | 6 +- userapi/storage/accounts/storage_wasm.go | 3 +- userapi/userapi_test.go | 2 +- 23 files changed, 553 insertions(+), 32 deletions(-) create mode 100644 clientapi/routing/openid.go create mode 100644 federationapi/routing/openid.go create mode 100644 userapi/storage/accounts/postgres/openid_table.go create mode 100644 userapi/storage/accounts/sqlite3/openid_table.go diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go new file mode 100644 index 000000000..13656e288 --- /dev/null +++ b/clientapi/routing/openid.go @@ -0,0 +1,70 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type openIDTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + MatrixServerName string `json:"matrix_server_name"` + ExpiresIn int64 `json:"expires_in"` +} + +// CreateOpenIDToken creates a new OpenID Connect (OIDC) token that a Matrix user +// can supply to an OpenID Relying Party to verify their identity +func CreateOpenIDToken( + req *http.Request, + userAPI api.UserInternalAPI, + device *api.Device, + userID string, + cfg *config.ClientAPI, +) util.JSONResponse { + // does the incoming user ID match the user that the token was issued for? + if userID != device.UserID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Cannot request tokens for other users"), + } + } + + request := api.PerformOpenIDTokenCreationRequest{ + UserID: userID, // this is the user ID from the incoming path + } + response := api.PerformOpenIDTokenCreationResponse{} + + err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed") + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: openIDTokenResponse{ + AccessToken: response.Token.Token, + TokenType: "Bearer", + MatrixServerName: string(cfg.Matrix.ServerName), + ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a56359b4c..5d4f90a45 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -469,7 +469,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Stub endpoints required by Riot + // Stub endpoints required by Element r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { @@ -506,7 +506,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - // Riot user settings + // Element user settings r0mux.Handle("/profile/{userID}", httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { @@ -592,7 +592,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Riot logs get flooded unless this is handled + // Element logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { @@ -685,6 +685,19 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/user/{userID}/openid/request_token", + httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 22732c518..060b82f97 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -58,7 +58,7 @@ func main() { accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, cfg.Global.ServerName, bcrypt.DefaultCost) + }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) if err != nil { logrus.Fatalln("Failed to connect to the database:", err.Error()) } diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 402987f98..1edb026f7 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -360,6 +360,11 @@ user_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 + # The length of time that a token issued for a relying party from + # /_matrix/client/r0/user/{userId}/openid/request_token endpoint + # is considered to be valid in milliseconds. + # The default lifetime is 3600000ms (60 minutes). + # openid_token_lifetime_ms: 3600000 # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go new file mode 100644 index 000000000..829dbccad --- /dev/null +++ b/federationapi/routing/openid.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type openIDUserInfoResponse struct { + Sub string `json:"sub"` +} + +// GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo +func GetOpenIDUserInfo( + httpReq *http.Request, + userAPI userapi.UserInternalAPI, +) util.JSONResponse { + token := httpReq.URL.Query().Get("access_token") + if len(token) == 0 { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.MissingArgument("access_token is missing"), + } + } + + req := userapi.QueryOpenIDTokenRequest{ + Token: token, + } + + var openIDTokenAttrResponse userapi.QueryOpenIDTokenResponse + err := userAPI.QueryOpenIDToken(httpReq.Context(), &req, &openIDTokenAttrResponse) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryOpenIDToken failed") + } + + var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub} + code := http.StatusOK + nowMS := time.Now().UnixNano() / int64(time.Millisecond) + if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS { + code = http.StatusUnauthorized + res = jsonerror.UnknownToken("Access Token unknown or expired") + } + + return util.JSONResponse{ + Code: code, + JSON: res, + } +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index ce018904f..07a28c3fc 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -462,4 +462,10 @@ func Setup( return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) + + v1fedmux.Handle("/openid/userinfo", + httputil.MakeExternalAPI("federation_openid_userinfo", func(req *http.Request) util.JSONResponse { + return GetOpenIDUserInfo(req, userAPI) + }), + ).Methods(http.MethodGet) } diff --git a/setup/base.go b/setup/base.go index b081ffaf8..6bdeb80f7 100644 --- a/setup/base.go +++ b/setup/base.go @@ -280,7 +280,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost) + db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index e69123842..2bf1be3dd 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -10,6 +10,9 @@ type UserAPI struct { // The cost when hashing passwords. BCryptCost int `yaml:"bcrypt_cost"` + // The length of time an OpenID token is condidered valid in milliseconds + OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` + // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` @@ -18,6 +21,8 @@ type UserAPI struct { DeviceDatabase DatabaseOptions `yaml:"device_database"` } +const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes + func (c *UserAPI) Defaults() { c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781" @@ -26,6 +31,7 @@ func (c *UserAPI) Defaults() { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" c.BCryptCost = bcrypt.DefaultCost + c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS } func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -33,4 +39,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) + checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 4eb5708c1..79aaebc0b 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -524,6 +524,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { return nil } +func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error { + return nil +} func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } @@ -548,6 +551,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { return nil } +func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { + return nil +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 99085c0f4..96160c10d 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -367,6 +367,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { return nil } +func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error { + return nil +} func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } @@ -391,6 +394,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { return nil } +func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { + return nil +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/sytest-whitelist b/sytest-whitelist index ed02fdecb..8c4585716 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -517,3 +517,6 @@ AS can set avatar for ghosted users AS can set displayname for ghosted users Ghost user must register before joining room Inviting an AS-hosted user asks the AS server +Can generate a openid access_token that can be exchanged for information about a user +Invalid openid access tokens are rejected +Requests to userinfo without access tokens are rejected diff --git a/userapi/api/api.go b/userapi/api/api.go index 45e4e834e..407350123 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -32,12 +32,14 @@ type UserInternalAPI interface { PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error + PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error + QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -226,6 +228,27 @@ type PerformAccountDeactivationResponse struct { AccountDeactivated bool } +// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationRequest struct { + UserID string +} + +// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationResponse struct { + Token OpenIDToken +} + +// QueryOpenIDTokenRequest is the request for QueryOpenIDToken +type QueryOpenIDTokenRequest struct { + Token string +} + +// QueryOpenIDTokenResponse is the response for QueryOpenIDToken +type QueryOpenIDTokenResponse struct { + Sub string // The Matrix User ID that generated the token + ExpiresAtMS int64 +} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string @@ -256,6 +279,24 @@ type Account struct { // TODO: Associations (e.g. with application services) } +// OpenIDToken represents an OpenID token +type OpenIDToken struct { + Token string + UserID string + ExpiresAtMS int64 +} + +// OpenIDTokenInfo represents the attributes associated with an issued OpenID token +type OpenIDTokenAttributes struct { + UserID string + ExpiresAtMS int64 +} + +// UserInfo is for returning information about the user an OpenID token was issued for +type UserInfo struct { + Sub string // The Matrix user's ID who generated the token +} + // ErrorForbidden is an error indicating that the supplied access token is forbidden type ErrorForbidden struct { Message string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 0d01afa19..21933c1c4 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -414,3 +414,31 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a res.AccountDeactivated = err == nil return err } + +// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user +func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { + token := util.RandomString(24) + + exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) + + res.Token = api.OpenIDToken{ + Token: token, + UserID: req.UserID, + ExpiresAtMS: exp, + } + + return err +} + +// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation +func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) + if err != nil { + return err + } + + res.Sub = openIDTokenAttrs.UserID + res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS + + return nil +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 680e4cb52..1cb5ef0a8 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -35,6 +35,7 @@ const ( PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" + PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -42,6 +43,7 @@ const ( QueryAccountDataPath = "/userapi/queryAccountData" QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -148,6 +150,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } +func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformOpenIDTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, @@ -207,3 +217,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api. apiURL := h.apiURL + QuerySearchProfilesPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") + defer span.Finish() + + apiURL := h.apiURL + QueryOpenIDTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index e495e3536..1c1cfdcd1 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -117,6 +117,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformOpenIDTokenCreationPath, + httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformOpenIDTokenCreationRequest{} + response := api.PerformOpenIDTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryProfilePath, httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { request := api.QueryProfileRequest{} @@ -195,6 +208,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOpenIDTokenPath, + httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse { + request := api.QueryOpenIDTokenRequest{} + response := api.QueryOpenIDTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(InputAccountDataPath, httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { request := api.InputAccountDataRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index c86b2c391..5aa61b909 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -52,6 +52,8 @@ type Database interface { GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) DeactivateAccount(ctx context.Context, localpart string) (err error) + CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) + GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go new file mode 100644 index 000000000..86c197059 --- /dev/null +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -0,0 +1,84 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about openid tokens issued for accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- The value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_at_ms BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + +type tokenStatements struct { + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if the token already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + expiresAtMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + return +} + +// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresAtMS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 3933fe5bd..c5e74ed15 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "strconv" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -39,25 +40,28 @@ type Database struct { db *sql.DB writer sqlutil.Writer sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - serverName gomatrixserverlib.ServerName - bcryptCost int + accounts accountsStatements + profiles profilesStatements + accountDatas accountDataStatements + threepids threepidStatements + openIDTokens tokenStatements + serverName gomatrixserverlib.ServerName + bcryptCost int + openIDTokenLifetimeMS int64 } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - bcryptCost: bcryptCost, + serverName: serverName, + db: db, + writer: sqlutil.NewDummyWriter(), + bcryptCost: bcryptCost, + openIDTokenLifetimeMS: openIDTokenLifetimeMS, } // Create tables before executing migrations so we don't fail if the table is missing, @@ -86,6 +90,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -341,3 +348,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued through OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS + err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) +} diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go new file mode 100644 index 000000000..80b9dd4cb --- /dev/null +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -0,0 +1,86 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- The value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_at_ms BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + +type tokenStatements struct { + db *sql.DB + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return err + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if the token already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + expiresAtMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + return +} + +// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresAtMS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 07cc68b35..c0f7118cb 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -21,6 +21,7 @@ import ( "errors" "strconv" "sync" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -37,12 +38,14 @@ type Database struct { writer sqlutil.Writer sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - serverName gomatrixserverlib.ServerName - bcryptCost int + accounts accountsStatements + profiles profilesStatements + accountDatas accountDataStatements + threepids threepidStatements + openIDTokens tokenStatements + serverName gomatrixserverlib.ServerName + bcryptCost int + openIDTokenLifetimeMS int64 accountsMu sync.Mutex profilesMu sync.Mutex @@ -51,16 +54,17 @@ type Database struct { } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - bcryptCost: bcryptCost, + serverName: serverName, + db: db, + writer: sqlutil.NewExclusiveWriter(), + bcryptCost: bcryptCost, + openIDTokenLifetimeMS: openIDTokenLifetimeMS, } // Create tables before executing migrations so we don't fail if the table is missing, @@ -90,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -379,3 +386,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS + err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) +} diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/accounts/storage.go index 28c437daa..3489c9d07 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/accounts/storage.go @@ -27,12 +27,12 @@ import ( // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost) + return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/accounts/storage_wasm.go index 8038326fe..11a88a20a 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/accounts/storage_wasm.go @@ -26,10 +26,11 @@ func NewDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, + openIDTokenLifetimeMS int64, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 9a45a2dc8..0141258e6 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -26,7 +26,7 @@ const ( func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost) + }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) if err != nil { t.Fatalf("failed to create account DB: %s", err) }