OpenID module tweaks (#599)

- specify expiry is ms rather than vague ts
- add OpenID token lifetime to configuration
- use Go naming conventions for the path params
- store plaintext token rather than hash
- remove openid table sqllite mutex
This commit is contained in:
Bruce MacDonald 2021-04-03 14:22:07 -07:00
parent 5f1272c74f
commit 8b915d3e2a
14 changed files with 107 additions and 106 deletions

View file

@ -64,7 +64,7 @@ func CreateOpenIDToken(
AccessToken: response.Token.Token, AccessToken: response.Token.Token,
TokenType: "Bearer", TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName), MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
}, },
} }
} }

View file

@ -685,7 +685,7 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token", r0mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil { if r := rateLimits.rateLimit(req); r != nil {
return *r return *r
@ -694,7 +694,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg) return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -360,6 +360,11 @@ user_api:
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
# The length of time that a token issed for a relying party from the
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds.
# The default lifetime is 3600000ms (60 minutes).
# openid_token_lifetime_ms: 3600000
# Configuration for Opentracing. # Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on

View file

@ -53,7 +53,7 @@ func GetOpenIDUserInfo(
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub} var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
code := http.StatusOK code := http.StatusOK
nowMS := time.Now().UnixNano() / int64(time.Millisecond) nowMS := time.Now().UnixNano() / int64(time.Millisecond)
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS { if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS {
code = http.StatusUnauthorized code = http.StatusUnauthorized
res = jsonerror.UnknownToken("Access Token unknown or expired") res = jsonerror.UnknownToken("Access Token unknown or expired")
} }

View file

@ -263,7 +263,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database { func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost) db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")
} }

View file

@ -10,6 +10,9 @@ type UserAPI struct {
// The cost when hashing passwords. // The cost when hashing passwords.
BCryptCost int `yaml:"bcrypt_cost"` BCryptCost int `yaml:"bcrypt_cost"`
// The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
// The Account database stores the login details and account information // The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI. // for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"` AccountDatabase DatabaseOptions `yaml:"account_database"`
@ -26,6 +29,7 @@ func (c *UserAPI) Defaults() {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
c.BCryptCost = bcrypt.DefaultCost c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = 3600000 // 60 minutes
} }
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
@ -33,4 +37,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
} }

View file

@ -245,8 +245,8 @@ type QueryOpenIDTokenRequest struct {
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken // QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct { type QueryOpenIDTokenResponse struct {
Sub string // The Matrix User ID that generated the token Sub string // The Matrix User ID that generated the token
ExpiresTS int64 ExpiresAtMS int64
} }
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
@ -281,15 +281,15 @@ type Account struct {
// OpenIDToken represents an OpenID token // OpenIDToken represents an OpenID token
type OpenIDToken struct { type OpenIDToken struct {
Token string Token string
UserID string UserID string
ExpiresTS int64 ExpiresAtMS int64
} }
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token // OpenIDTokenInfo represents the attributes associated with an issued OpenID token
type OpenIDTokenAttributes struct { type OpenIDTokenAttributes struct {
UserID string UserID string
ExpiresTS int64 ExpiresAtMS int64
} }
// UserInfo is for returning information about the user an OpenID token was issued for // UserInfo is for returning information about the user an OpenID token was issued for

View file

@ -20,9 +20,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"hash"
"hash/fnv"
"time"
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -421,16 +418,13 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user // PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24) token := util.RandomString(24)
tokenHash := getTokenHash(token)
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := nowMS + (3600 * 1000) // 60 minutes
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS) exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{ res.Token = api.OpenIDToken{
Token: token, Token: token,
UserID: req.UserID, UserID: req.UserID,
ExpiresTS: expiresMS, ExpiresAtMS: exp,
} }
return err return err
@ -438,21 +432,13 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
tokenHash := getTokenHash(req.Token) openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
if err != nil { if err != nil {
return err return err
} }
res.Sub = openIDTokenAttrs.UserID res.Sub = openIDTokenAttrs.UserID
res.ExpiresTS = openIDTokenAttrs.ExpiresTS res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS
return nil return nil
} }
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
func getTokenHash(token string) hash.Hash32 {
tokenHash := fnv.New32a()
_, _ = tokenHash.Write([]byte(token))
return tokenHash
}

View file

@ -52,8 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = ` const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts. -- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens ( CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is a hash the token value -- The value of the token issued to a user
token_hash TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution). -- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL token_expires_at_ms BIGINT NOT NULL
); );
` `
const insertTokenSQL = "" + const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type tokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
@ -50,28 +50,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *tokenStatements) insertToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
tokenHash, localpart string, token, localpart string,
expiresTimeMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS) _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return return
} }
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash( func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
tokenHash string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS, &openIDTokenAttrs.ExpiresAtMS,
) )
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"strconv" "strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -39,26 +40,28 @@ type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{ d := &Database{
serverName: serverName, serverName: serverName,
db: db, db: db,
writer: sqlutil.NewDummyWriter(), writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost, bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
@ -349,18 +352,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued through OpenID Connect // CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken( func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
tokenHash, localpart string, token, localpart string,
expirationTS int64, ) (int64, error) {
) error { expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS) return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
}) })
return expiresAtMS, err
} }
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token // GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes( func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context, ctx context.Context,
tokenHash string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash) return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
} }

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = ` const openIDTokenSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens ( CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is the hash of the token value -- The value of the token issued to a user
token_hash TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution). -- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL token_expires_at_ms BIGINT NOT NULL
); );
` `
const insertTokenSQL = "" + const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type tokenStatements struct {
db *sql.DB db *sql.DB
@ -52,28 +52,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *tokenStatements) insertToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token_hash, localpart string, token, localpart string,
expiresTimeMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS) _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return return
} }
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash( func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
tokenHash string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS, &openIDTokenAttrs.ExpiresAtMS,
) )
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {

View file

@ -21,6 +21,7 @@ import (
"errors" "errors"
"strconv" "strconv"
"sync" "sync"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -37,32 +38,33 @@ type Database struct {
writer sqlutil.Writer writer sqlutil.Writer
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex accountsMu sync.Mutex
profilesMu sync.Mutex profilesMu sync.Mutex
accountDatasMu sync.Mutex accountDatasMu sync.Mutex
threepidsMu sync.Mutex threepidsMu sync.Mutex
openIDsMu sync.Mutex
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{ d := &Database{
serverName: serverName, serverName: serverName,
db: db, db: db,
writer: sqlutil.NewExclusiveWriter(), writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost, bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }
// Create tables before executing migrations so we don't fail if the table is missing, // Create tables before executing migrations so we don't fail if the table is missing,
@ -388,20 +390,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued for OpenID Connect // CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken( func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
tokenHash, localpart string, token, localpart string,
expirationTS int64, ) (int64, error) {
) error { expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
d.openIDsMu.Lock() err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
defer d.openIDsMu.Unlock() return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
}) })
return expiresAtMS, err
} }
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token // GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes( func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context, ctx context.Context,
tokenHash string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash) return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
} }

View file

@ -27,12 +27,12 @@ import (
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost) return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }