From 6c9a8b96dda4d0c307868687a6d51476c1859407 Mon Sep 17 00:00:00 2001 From: Jake Hemmerle Date: Wed, 4 Nov 2020 23:51:06 -0500 Subject: [PATCH] feat pt 1: db working, controller not yet --- clientapi/openid/openid.go | 7 ++ clientapi/routing/openid.go | 54 ++++++++++ clientapi/routing/routing.go | 13 +++ federationapi/routing/openid.go | 14 +++ userapi/api/api.go | 32 ++++++ userapi/internal/api.go | 41 ++++++++ userapi/inthttp/client.go | 18 ++++ userapi/inthttp/server.go | 26 +++++ userapi/storage/accounts/interface.go | 2 + .../storage/accounts/postgres/openid_table.go | 95 ++++++++++++++++++ userapi/storage/accounts/postgres/storage.go | 24 +++++ .../storage/accounts/sqlite3/openid_table.go | 99 +++++++++++++++++++ userapi/storage/accounts/sqlite3/storage.go | 27 +++++ 13 files changed, 452 insertions(+) create mode 100644 clientapi/openid/openid.go create mode 100644 clientapi/routing/openid.go create mode 100644 federationapi/routing/openid.go create mode 100644 userapi/storage/accounts/postgres/openid_table.go create mode 100644 userapi/storage/accounts/sqlite3/openid_table.go diff --git a/clientapi/openid/openid.go b/clientapi/openid/openid.go new file mode 100644 index 000000000..7d89334e9 --- /dev/null +++ b/clientapi/openid/openid.go @@ -0,0 +1,7 @@ +package openid + +// TokenRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.6.1#id603 +type TokenRequest struct { + UserID string `json:"userId"` + RelyingParty string `json:"relyingParty"` +} diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go new file mode 100644 index 000000000..3cf63306c --- /dev/null +++ b/clientapi/routing/openid.go @@ -0,0 +1,54 @@ +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type openIDTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + MatrixServerName string `json:"matrix_server_name"` + ExpiresIn int64 `json:"expires_in"` +} + +// CreateOpenIDToken creates a new OpenID Connect token that a Matrix user +// can supply to an OpenID Relying Party +func CreateOpenIDToken( + req *http.Request, + userAPI api.UserInternalAPI, + device *api.Device, + userID, relyingParty string, + cfg *config.ClientAPI, +) util.JSONResponse { + if userID != device.UserID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not match the current user"), + } + } + + request := api.PerformOpenIDTokenCreationRequest{ + UserID: userID, + RelyingParty: relyingParty} + response := api.PerformOpenIDTokenCreationResponse{} + + err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) + if err != nil { + return jsonerror.InternalServerError() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: openIDTokenResponse{ + AccessToken: response.Token.Token, + TokenType: "Bearer", + MatrixServerName: string(cfg.Matrix.ServerName), + ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 65b622b3a..7cfdbc75b 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -661,6 +661,19 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/user/{userId}/openid/request_token", + httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return CreateOpenIDToken(req, userAPI, device, vars["userId"], vars["relyingParty"], cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go new file mode 100644 index 000000000..68cfa82fc --- /dev/null +++ b/federationapi/routing/openid.go @@ -0,0 +1,14 @@ +package routing + +/* + +import ( + "net/http" +) + +func GetInvite( + httpReq *http.Request, +) { + return +} +*/ diff --git a/userapi/api/api.go b/userapi/api/api.go index 6c3f3c69c..7f76d544b 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -31,12 +31,14 @@ type UserInternalAPI interface { PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error + PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error + QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -214,6 +216,27 @@ type PerformAccountDeactivationResponse struct { AccountDeactivated bool } +// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationRequest struct { + UserID string + RelyingParty string +} + +// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation +type PerformOpenIDTokenCreationResponse struct { + Token OpenIDToken +} + +// QueryOpenIDTokenRequest is the request for QueryOpenIDToken +type QueryOpenIDTokenRequest struct { + Token string +} + +// QueryOpenIDTokenResponse is the response for QueryOpenIDToken +type QueryOpenIDTokenResponse struct { + Token OpenIDToken +} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string @@ -241,6 +264,15 @@ type Account struct { // TODO: Associations (e.g. with application services) } +// OpenIDToken represents an OpenID token +type OpenIDToken struct { + Token string + UserID string + CreatedTS int64 + ExpiresTS int64 + RelyingParty string +} + // ErrorForbidden is an error indicating that the supplied access token is forbidden type ErrorForbidden struct { Message string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 81d002414..658549439 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" @@ -395,3 +396,43 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a res.AccountDeactivated = err == nil return err } + +// PerformOpenIDTokenCreation creates a new token from an optional relying party. +func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + + if domain != a.ServerName { + return fmt.Errorf("cannot create OpenID token for accounts not on this sercer: got %s want %s", domain, a.ServerName) + } + + tokenString := util.RandomString(24) + createdMS := time.Now().UnixNano() / int64(time.Millisecond) + expiresMS := createdMS + (3600 * 1000) // 60 minutes + + err = a.AccountDB.CreateOpenIDToken(ctx, tokenString, localpart, createdMS, expiresMS, req.RelyingParty) + + res.Token = api.OpenIDToken{ + Token: tokenString, + UserID: userutil.MakeUserID(localpart, a.ServerName), + CreatedTS: createdMS, + ExpiresTS: expiresMS, + RelyingParty: req.RelyingParty, + } + + return err +} + +// QueryOpenIDToken returns the user information from the provided token string +func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + token, err := a.AccountDB.GetOpenIDToken(ctx, req.Token) + if err != nil { + return err + } + + res.Token = *token + + return nil +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4d9dcc416..7a122cec0 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -34,6 +34,7 @@ const ( PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" + PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -41,6 +42,7 @@ const ( QueryAccountDataPath = "/userapi/queryAccountData" QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -135,6 +137,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } +func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformOpenIDTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, @@ -194,3 +204,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api. apiURL := h.apiURL + QuerySearchProfilesPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") + defer span.Finish() + + apiURL := h.apiURL + QueryOpenIDTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 81e936e58..f7cfba232 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -104,6 +104,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformOpenIDTokenCreationPath, + httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformOpenIDTokenCreationRequest{} + response := api.PerformOpenIDTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryProfilePath, httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { request := api.QueryProfileRequest{} @@ -182,6 +195,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOpenIDTokenPath, + httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse { + request := api.QueryOpenIDTokenRequest{} + response := api.QueryOpenIDTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(InputAccountDataPath, httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { request := api.InputAccountDataRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index c86b2c391..e96c328b7 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -52,6 +52,8 @@ type Database interface { GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) DeactivateAccount(ctx context.Context, localpart string) (err error) + CreateOpenIDToken(ctx context.Context, token, localpart string, creationTS, expirationTS int64, rp string) (err error) + GetOpenIDToken(ctx context.Context, token string) (*api.OpenIDToken, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go new file mode 100644 index 000000000..577d58a47 --- /dev/null +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -0,0 +1,95 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_openid ( + -- This is the token value, empty by default + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- When this token was first created, as a unix timestamp (ms resolution). + token_created_ts BIGINT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_ts BIGINT NOT NULL, + -- (optional) Relying Party the token was created for + token_rp TEXT, +); +-- Create sequence for autogenerated numeric usernames +-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; +` + +const insertTokenSQL = "" + + "INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)" + +const selectTokenSQL = "" + + "SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1" + +type tokenStatements struct { + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + createdTimeMS, expiresTimeMS int64, + tokenRP string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + + if tokenRP == "" { + _, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil) + } else { + _, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP) + } + return +} + +func (s *tokenStatements) selectToken( + ctx context.Context, + token string, +) (*api.OpenIDToken, error) { + var openIDToken api.OpenIDToken + var localpart string + + err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDToken.Token, + localpart, + &openIDToken.CreatedTS, + &openIDToken.ExpiresTS, + &openIDToken.RelyingParty, + ) + if err != nil { + return nil, err + } + + openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName) + return &openIDToken, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 40c4b8ff5..4e4dfaa01 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -43,6 +43,7 @@ type Database struct { profiles profilesStatements accountDatas accountDataStatements threepids threepidStatements + openIDTokens tokenStatements serverName gomatrixserverlib.ServerName } @@ -84,6 +85,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -337,3 +341,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken creates a new token for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, + createdTS, expirationTS int64, + rp string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp) + }) +} + +// GetOpenIDToken gets a whole token +func (d *Database) GetOpenIDToken( + ctx context.Context, + token string, +) (*api.OpenIDToken, error) { + return d.openIDTokens.selectToken(ctx, token) +} diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go new file mode 100644 index 000000000..4fdd9a3c8 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -0,0 +1,99 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_openid ( + -- This is the token value, empty by default + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- When this token was first created, as a unix timestamp (ms resolution). + token_created_ts BIGINT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_ts BIGINT NOT NULL, + -- (optional) Relying Party the token was created for + token_rp TEXT, +); +-- Create sequence for autogenerated numeric usernames +-- CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; +` + +const insertTokenSQL = "" + + "INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)" + +const selectTokenSQL = "" + + "SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1" + +type tokenStatements struct { + db *sql.DB + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return err + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// tokenRP is the OpenID Relying Party; if not specified, it's left nil +// Returns new token, otherwise returns error if token already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + createdTimeMS, expiresTimeMS int64, + tokenRP string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + + if tokenRP == "" { + _, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, nil) + } else { + _, err = stmt.ExecContext(ctx, token, localpart, createdTimeMS, expiresTimeMS, tokenRP) + } + return +} + +func (s *tokenStatements) selectToken( + ctx context.Context, + token string, +) (openIDToken *api.OpenIDToken, err error) { + var localpart string + + err = s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDToken.Token, + localpart, + &openIDToken.CreatedTS, + &openIDToken.ExpiresTS, + &openIDToken.RelyingParty, + ) + if err != nil { + return nil, err + } + + openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName) + return openIDToken, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 0be7bcbe7..89a60975c 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -42,12 +42,14 @@ type Database struct { profiles profilesStatements accountDatas accountDataStatements threepids threepidStatements + openIDTokens tokenStatements serverName gomatrixserverlib.ServerName accountsMu sync.Mutex profilesMu sync.Mutex accountDatasMu sync.Mutex threepidsMu sync.Mutex + openIDsMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -89,6 +91,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -378,3 +383,25 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken creates a new token for the +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, + createdTS, expirationTS int64, + rp string, +) error { + d.openIDsMu.Lock() + defer d.openIDsMu.Unlock() + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, createdTS, expirationTS, rp) + }) +} + +// GetOpenIDToken returns an OpenIDToken struct from the token string +func (d *Database) GetOpenIDToken( + ctx context.Context, + token string, +) (*api.OpenIDToken, error) { + return d.openIDTokens.selectToken(ctx, token) +}