diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 1beac9757..953956aa5 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt @@ -55,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS) return } @@ -70,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 2a0edd81f..c02ad6d12 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -383,11 +383,15 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serv // CreateOpenIDToken persists a new token that was issued for OpenID Connect func (d *Database) CreateOpenIDToken( ctx context.Context, - token, localpart string, + token, userID string, ) (int64, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return 0, nil + } expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS) }) return expiresAtMS, err } diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9158cb365..2de85005f 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index a9224db6b..636ce4efc 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 230bc1433..471e496cd 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 0e334f481..0201c39ba 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB @@ -57,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS) return } @@ -72,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 4ae2961f4..e14776cf3 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -79,7 +79,7 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 73007afdc..b088d15cd 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -79,7 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, - serverName gomatrixserverlib.ServerName, + serverName gomatrixserverlib.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) {