This commit is contained in:
Neil Alexander 2022-11-07 16:01:12 +00:00
parent 62dd0afc0b
commit be2f90f0b8
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 30 additions and 15 deletions

View file

@ -3,6 +3,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
@ -55,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
func (s *openIDTokenStatements) InsertOpenIDToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
return return
} }
@ -70,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS, &openIDTokenAttrs.ExpiresAtMS,
) )
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db") log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -383,11 +383,15 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serv
// CreateOpenIDToken persists a new token that was issued for OpenID Connect // CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken( func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
token, localpart string, token, userID string,
) (int64, error) { ) (int64, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return 0, nil
}
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
}) })
return expiresAtMS, err return expiresAtMS, err
} }

View file

@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
appservice_id TEXT, appservice_id TEXT,

View file

@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
localpart TEXT , localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT, created_ts BIGINT,
display_name TEXT, display_name TEXT,
last_seen_ts BIGINT, last_seen_ts BIGINT,

View file

@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
appservice_id TEXT, appservice_id TEXT,

View file

@ -3,6 +3,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
db *sql.DB db *sql.DB
@ -57,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) _, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
return return
} }
@ -72,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS, &openIDTokenAttrs.ExpiresAtMS,
) )
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db") log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -79,7 +79,7 @@ type LoginTokenTable interface {
} }
type OpenIDTable interface { type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
} }

View file

@ -79,7 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable, accDB tables.AccountsTable,
devDB tables.DevicesTable, devDB tables.DevicesTable,
localpart string, localpart string,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName, // nolint:unparam
accType api.AccountType, accType api.AccountType,
userAgent string, userAgent string,
) { ) {