OpenID module tweaks (#599)

- specify expiry is ms rather than vague ts
- add OpenID token lifetime to configuration
- use Go naming conventions for the path params
- store plaintext token rather than hash
- remove openid table sqllite mutex
This commit is contained in:
Bruce MacDonald 2021-04-03 14:22:07 -07:00
parent 5f1272c74f
commit 8b915d3e2a
14 changed files with 107 additions and 106 deletions

View file

@ -64,7 +64,7 @@ func CreateOpenIDToken(
AccessToken: response.Token.Token,
TokenType: "Bearer",
MatrixServerName: string(cfg.Matrix.ServerName),
ExpiresIn: response.Token.ExpiresTS / 1000, // convert ms to s
ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
},
}
}

View file

@ -685,7 +685,7 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user/{userId}/openid/request_token",
r0mux.Handle("/user/{userID}/openid/request_token",
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
@ -694,7 +694,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return CreateOpenIDToken(req, userAPI, device, vars["userId"], cfg)
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -360,6 +360,11 @@ user_api:
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# The length of time that a token issed for a relying party from the
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
# is considered to be valid in milliseconds.
# The default lifetime is 3600000ms (60 minutes).
# openid_token_lifetime_ms: 3600000
# Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on

View file

@ -53,7 +53,7 @@ func GetOpenIDUserInfo(
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
code := http.StatusOK
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresTS {
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS {
code = http.StatusUnauthorized
res = jsonerror.UnknownToken("Access Token unknown or expired")
}

View file

@ -263,7 +263,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost)
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}

View file

@ -10,6 +10,9 @@ type UserAPI struct {
// The cost when hashing passwords.
BCryptCost int `yaml:"bcrypt_cost"`
// The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
@ -26,6 +29,7 @@ func (c *UserAPI) Defaults() {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = 3600000 // 60 minutes
}
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
@ -33,4 +37,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
}

View file

@ -246,7 +246,7 @@ type QueryOpenIDTokenRequest struct {
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
type QueryOpenIDTokenResponse struct {
Sub string // The Matrix User ID that generated the token
ExpiresTS int64
ExpiresAtMS int64
}
// Device represents a client's device (mobile, web, etc)
@ -283,13 +283,13 @@ type Account struct {
type OpenIDToken struct {
Token string
UserID string
ExpiresTS int64
ExpiresAtMS int64
}
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
type OpenIDTokenAttributes struct {
UserID string
ExpiresTS int64
ExpiresAtMS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for

View file

@ -20,9 +20,6 @@ import (
"encoding/json"
"errors"
"fmt"
"hash"
"hash/fnv"
"time"
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
@ -421,16 +418,13 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
tokenHash := getTokenHash(token)
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
expiresMS := nowMS + (3600 * 1000) // 60 minutes
err := a.AccountDB.CreateOpenIDToken(ctx, fmt.Sprint(tokenHash.Sum32()), req.UserID, expiresMS)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
UserID: req.UserID,
ExpiresTS: expiresMS,
ExpiresAtMS: exp,
}
return err
@ -438,21 +432,13 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
tokenHash := getTokenHash(req.Token)
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, fmt.Sprint(tokenHash.Sum32()))
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
res.Sub = openIDTokenAttrs.UserID
res.ExpiresTS = openIDTokenAttrs.ExpiresTS
res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS
return nil
}
// used for getting the format in which the token is stored to prevent plaintext storage of sensitive info
func getTokenHash(token string) hash.Hash32 {
tokenHash := fnv.New32a()
_, _ = tokenHash.Write([]byte(token))
return tokenHash
}

View file

@ -52,8 +52,8 @@ type Database interface {
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, tokenHash, localpart string, expirationTS int64) (err error)
GetOpenIDTokenAttributes(ctx context.Context, tokenHash string) (*api.OpenIDTokenAttributes, error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is a hash the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
token_expires_at_ms BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
insertTokenStmt *sql.Stmt
@ -50,28 +50,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
tokenHash, localpart string,
expiresTimeMS int64,
token, localpart string,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, tokenHash, localpart, expiresTimeMS)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
&openIDTokenAttrs.ExpiresAtMS,
)
if err != nil {
if err != sql.ErrNoRows {

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"errors"
"strconv"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -46,10 +47,11 @@ type Database struct {
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@ -59,6 +61,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
db: db,
writer: sqlutil.NewDummyWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
@ -349,18 +352,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}

View file

@ -13,20 +13,20 @@ import (
const openIDTokenSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS open_id_tokens (
-- This is the hash of the token value
token_hash TEXT NOT NULL PRIMARY KEY,
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_ts BIGINT NOT NULL
token_expires_at_ms BIGINT NOT NULL
);
`
const insertTokenSQL = "" +
"INSERT INTO open_id_tokens(token_hash, localpart, token_expires_ts) VALUES ($1, $2, $3)"
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" +
"SELECT localpart, token_expires_ts FROM open_id_tokens WHERE token_hash = $1"
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct {
db *sql.DB
@ -52,28 +52,28 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam
}
// insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if token hash already exists.
// Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken(
ctx context.Context,
txn *sql.Tx,
token_hash, localpart string,
expiresTimeMS int64,
token, localpart string,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token_hash, localpart, expiresTimeMS)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
return
}
// selectOpenIDTokenAtrributesByTokenHash gets the attributes associated with an OpenID token from the DB
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributesByTokenHash(
func (s *tokenStatements) selectOpenIDTokenAtrributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
err := s.selectTokenStmt.QueryRowContext(ctx, tokenHash).Scan(
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&openIDTokenAttrs.ExpiresTS,
&openIDTokenAttrs.ExpiresAtMS,
)
if err != nil {
if err != sql.ErrNoRows {

View file

@ -21,6 +21,7 @@ import (
"errors"
"strconv"
"sync"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -44,16 +45,16 @@ type Database struct {
openIDTokens tokenStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
openIDsMu sync.Mutex
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@ -63,6 +64,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
db: db,
writer: sqlutil.NewExclusiveWriter(),
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing,
@ -388,20 +390,19 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
tokenHash, localpart string,
expirationTS int64,
) error {
d.openIDsMu.Lock()
defer d.openIDsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, tokenHash, localpart, expirationTS)
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
tokenHash string,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributesByTokenHash(ctx, tokenHash)
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}

View file

@ -27,12 +27,12 @@ import (
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) {
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost)
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost)
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
default:
return nil, fmt.Errorf("unexpected database type")
}