OpenID module tweaks (#599)

- specify expiry is ms rather than vague ts
- add OpenID token lifetime to configuration
- use Go naming conventions for the path params
- store plaintext token rather than hash
- remove openid table sqllite mutex
This commit is contained in:
Bruce MacDonald 2021-04-03 14:22:07 -07:00
parent 5f1272c74f
commit 8b915d3e2a
14 changed files with 107 additions and 106 deletions

View file

@ -64,7 +64,7 @@ func CreateOpenIDToken(
AccessToken: response.Token.Token,
TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
},
}
}

View file

@ -685,7 +685,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token",
r0mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
@ -694,7 +694,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg)
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -360,6 +360,11 @@ user_api:
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# The length of time that a token issed for a relying party from the
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds.
# The default lifetime is 3600000ms (60 minutes).
# openid_token_lifetime_ms: 3600000
# Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on

View file

@ -53,7 +53,7 @@ func GetOpenIDUserInfo(
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
code := http.StatusOK
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS {
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS {
code = http.StatusUnauthorized
res = jsonerror.UnknownToken("Access Token unknown or expired")
}

View file

@ -263,7 +263,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost)
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}

View file

@ -10,6 +10,9 @@ type UserAPI struct {
// The cost when hashing passwords.
BCryptCost int `yaml:"bcrypt_cost"`
// The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
@ -26,6 +29,7 @@ func (c *UserAPI) Defaults() {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = 3600000 // 60 minutes
}
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
@ -33,4 +37,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
}

View file

@ -245,8 +245,8 @@ type QueryOpenIDTokenRequest struct {
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct {
Sub string // The Matrix User ID that generated the token
ExpiresTS int64
Sub string // The Matrix User ID that generated the token
ExpiresAtMS int64
}
// Device represents a client's device (mobile, web, etc)
@ -281,15 +281,15 @@ type Account struct {
// OpenIDToken represents an OpenID token
type OpenIDToken struct {
Token string
UserID string
ExpiresTS int64
Token string
UserID string
ExpiresAtMS int64
}
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
type OpenIDTokenAttributes struct {
UserID string
ExpiresTS int64
UserID string
ExpiresAtMS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for

View file

@ -20,9 +20,6 @@ import (
"encoding/json"
"errors"
"fmt"
"hash"
"hash/fnv"
"time"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -421,16 +418,13 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
tokenHash := getTokenHash(token)
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := nowMS + (3600 * 1000) // 60 minutes
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
UserID: req.UserID,
ExpiresTS: expiresMS,
Token: token,
UserID: req.UserID,
ExpiresAtMS: exp,
}
return err
@ -438,21 +432,13 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
tokenHash := getTokenHash(req.Token)
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
res.Sub = openIDTokenAttrs.UserID
res.ExpiresTS = openIDTokenAttrs.ExpiresTS
res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS
return nil
}
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
func getTokenHash(token string) hash.Hash32 {
tokenHash := fnv.New32a()
_, _ = tokenHash.Write([]byte(token))
return tokenHash
}

View file

@ -52,8 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error)
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is a hash the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
token_expires_at_ms BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
insertTokenStmt *sql.Stmt
@ -50,28 +50,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
tokenHash, localpart string,
expiresTimeMS int64,
token, localpart string,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
&openIDTokenAttrs.ExpiresAtMS,
)
if err != nil {
if err != sql.ErrNoRows {

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"errors"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -39,26 +40,28 @@ type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
@ -349,18 +352,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is the hash of the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
token_expires_at_ms BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
db *sql.DB
@ -52,28 +52,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token_hash, localpart string,
expiresTimeMS int64,
token, localpart string,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
&openIDTokenAttrs.ExpiresAtMS,
)
if err != nil {
if err != sql.ErrNoRows {

View file

@ -21,6 +21,7 @@ import (
"errors"
"strconv"
"sync"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -37,32 +38,33 @@ type Database struct {
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
openIDsMu sync.Mutex
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
@ -388,20 +390,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
d.openIDsMu.Lock()
defer d.openIDsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}

View file

@ -27,12 +27,12 @@ import (
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost)
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost)
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
default:
return nil, fmt.Errorf("unexpected database type")
}