mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
OpenID module tweaks (#599)
- specify expiry is ms rather than vague ts - add OpenID token lifetime to configuration - use Go naming conventions for the path params - store plaintext token rather than hash - remove openid table sqllite mutex
This commit is contained in:
parent
5f1272c74f
commit
8b915d3e2a
|
|
@ -64,7 +64,7 @@ func CreateOpenIDToken(
|
|||
AccessToken: response.Token.Token,
|
||||
TokenType: "Bearer",
|
||||
MatrixServerName: string(cfg.Matrix.ServerName),
|
||||
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
|
||||
ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -685,7 +685,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
r0mux.Handle("/user/{userId}/openid/request_token",
|
||||
r0mux.Handle("/user/{userID}/openid/request_token",
|
||||
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.rateLimit(req); r != nil {
|
||||
return *r
|
||||
|
|
@ -694,7 +694,7 @@ func Setup(
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg)
|
||||
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
|
|
|
|||
|
|
@ -360,6 +360,11 @@ user_api:
|
|||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
# The length of time that a token issed for a relying party from the
|
||||
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||
# is considered to be valid in milliseconds.
|
||||
# The default lifetime is 3600000ms (60 minutes).
|
||||
# openid_token_lifetime_ms: 3600000
|
||||
|
||||
# Configuration for Opentracing.
|
||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ func GetOpenIDUserInfo(
|
|||
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
|
||||
code := http.StatusOK
|
||||
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS {
|
||||
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS {
|
||||
code = http.StatusUnauthorized
|
||||
res = jsonerror.UnknownToken("Access Token unknown or expired")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -263,7 +263,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
|||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
||||
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost)
|
||||
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ type UserAPI struct {
|
|||
// The cost when hashing passwords.
|
||||
BCryptCost int `yaml:"bcrypt_cost"`
|
||||
|
||||
// The length of time an OpenID token is condidered valid in milliseconds
|
||||
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
||||
|
||||
// The Account database stores the login details and account information
|
||||
// for local users. It is accessed by the UserAPI.
|
||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||
|
|
@ -26,6 +29,7 @@ func (c *UserAPI) Defaults() {
|
|||
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
||||
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
|
||||
c.BCryptCost = bcrypt.DefaultCost
|
||||
c.OpenIDTokenLifetimeMS = 3600000 // 60 minutes
|
||||
}
|
||||
|
||||
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||
|
|
@ -33,4 +37,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
|||
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
||||
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
||||
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
|
||||
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -245,8 +245,8 @@ type QueryOpenIDTokenRequest struct {
|
|||
|
||||
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
|
||||
type QueryOpenIDTokenResponse struct {
|
||||
Sub string // The Matrix User ID that generated the token
|
||||
ExpiresTS int64
|
||||
Sub string // The Matrix User ID that generated the token
|
||||
ExpiresAtMS int64
|
||||
}
|
||||
|
||||
// Device represents a client's device (mobile, web, etc)
|
||||
|
|
@ -281,15 +281,15 @@ type Account struct {
|
|||
|
||||
// OpenIDToken represents an OpenID token
|
||||
type OpenIDToken struct {
|
||||
Token string
|
||||
UserID string
|
||||
ExpiresTS int64
|
||||
Token string
|
||||
UserID string
|
||||
ExpiresAtMS int64
|
||||
}
|
||||
|
||||
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
|
||||
type OpenIDTokenAttributes struct {
|
||||
UserID string
|
||||
ExpiresTS int64
|
||||
UserID string
|
||||
ExpiresAtMS int64
|
||||
}
|
||||
|
||||
// UserInfo is for returning information about the user an OpenID token was issued for
|
||||
|
|
|
|||
|
|
@ -20,9 +20,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice/types"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
|
|
@ -421,16 +418,13 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
|||
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
|
||||
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
||||
token := util.RandomString(24)
|
||||
tokenHash := getTokenHash(token)
|
||||
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
expiresMS := nowMS + (3600 * 1000) // 60 minutes
|
||||
|
||||
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS)
|
||||
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
|
||||
|
||||
res.Token = api.OpenIDToken{
|
||||
Token: token,
|
||||
UserID: req.UserID,
|
||||
ExpiresTS: expiresMS,
|
||||
Token: token,
|
||||
UserID: req.UserID,
|
||||
ExpiresAtMS: exp,
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
@ -438,21 +432,13 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
|
|||
|
||||
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
|
||||
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||
tokenHash := getTokenHash(req.Token)
|
||||
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
|
||||
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res.Sub = openIDTokenAttrs.UserID
|
||||
res.ExpiresTS = openIDTokenAttrs.ExpiresTS
|
||||
res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
|
||||
func getTokenHash(token string) hash.Hash32 {
|
||||
tokenHash := fnv.New32a()
|
||||
_, _ = tokenHash.Write([]byte(token))
|
||||
return tokenHash
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,8 +52,8 @@ type Database interface {
|
|||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||
|
|
|
|||
|
|
@ -13,20 +13,20 @@ import (
|
|||
const openIDTokenSchema = `
|
||||
-- Stores data about openid tokens issued for accounts.
|
||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||
-- This is a hash the token value
|
||||
token_hash TEXT NOT NULL PRIMARY KEY,
|
||||
-- The value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_ts BIGINT NOT NULL
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
|
|
@ -50,28 +50,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
|
|||
}
|
||||
|
||||
// insertToken inserts a new OpenID Connect token to the DB.
|
||||
// Returns new token, otherwise returns error if token hash already exists.
|
||||
// Returns new token, otherwise returns error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
tokenHash, localpart string,
|
||||
expiresTimeMS int64,
|
||||
token, localpart string,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
|
||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||
// Returns the existing token's attributes, or err if no token is found
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||
ctx context.Context,
|
||||
tokenHash string,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&openIDTokenAttrs.ExpiresTS,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
|
@ -39,26 +40,28 @@ type Database struct {
|
|||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
sqlutil.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
openIDTokenLifetimeMS int64
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewDummyWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewDummyWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
|
|
@ -349,18 +352,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
|
|||
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
tokenHash, localpart string,
|
||||
expirationTS int64,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
|
||||
token, localpart string,
|
||||
) (int64, error) {
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
||||
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||
func (d *Database) GetOpenIDTokenAttributes(
|
||||
ctx context.Context,
|
||||
tokenHash string,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,20 +13,20 @@ import (
|
|||
const openIDTokenSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||
-- This is the hash of the token value
|
||||
token_hash TEXT NOT NULL PRIMARY KEY,
|
||||
-- The value of the token issued to a user
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_ts BIGINT NOT NULL
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
|
||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||
|
||||
type tokenStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -52,28 +52,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
|
|||
}
|
||||
|
||||
// insertToken inserts a new OpenID Connect token to the DB.
|
||||
// Returns new token, otherwise returns error if token hash already exists.
|
||||
// Returns new token, otherwise returns error if the token already exists.
|
||||
func (s *tokenStatements) insertToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token_hash, localpart string,
|
||||
expiresTimeMS int64,
|
||||
token, localpart string,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
|
||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||
// Returns the existing token's attributes, or err if no token is found
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
|
||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||
ctx context.Context,
|
||||
tokenHash string,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&openIDTokenAttrs.ExpiresTS,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
|
@ -37,32 +38,33 @@ type Database struct {
|
|||
writer sqlutil.Writer
|
||||
|
||||
sqlutil.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
accountDatas accountDataStatements
|
||||
threepids threepidStatements
|
||||
openIDTokens tokenStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
bcryptCost int
|
||||
openIDTokenLifetimeMS int64
|
||||
|
||||
accountsMu sync.Mutex
|
||||
profilesMu sync.Mutex
|
||||
accountDatasMu sync.Mutex
|
||||
threepidsMu sync.Mutex
|
||||
openIDsMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Database{
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
serverName: serverName,
|
||||
db: db,
|
||||
writer: sqlutil.NewExclusiveWriter(),
|
||||
bcryptCost: bcryptCost,
|
||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}
|
||||
|
||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||
|
|
@ -388,20 +390,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
|
|||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
tokenHash, localpart string,
|
||||
expirationTS int64,
|
||||
) error {
|
||||
d.openIDsMu.Lock()
|
||||
defer d.openIDsMu.Unlock()
|
||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
|
||||
token, localpart string,
|
||||
) (int64, error) {
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
||||
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||
func (d *Database) GetOpenIDTokenAttributes(
|
||||
ctx context.Context,
|
||||
tokenHash string,
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
|
||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ import (
|
|||
|
||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) {
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost)
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost)
|
||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue