This commit is contained in:
Neil Alexander 2022-11-07 16:01:12 +00:00
parent 62dd0afc0b
commit be2f90f0b8
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 30 additions and 15 deletions

View file

@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt
@ -55,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
return
}
@ -70,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -383,11 +383,15 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serv
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
token, userID string,
) (int64, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return 0, nil
}
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
})
return expiresAtMS, err
}

View file

@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,

View file

@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,

View file

@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,

View file

@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB
@ -57,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
return
}
@ -72,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -79,7 +79,7 @@ type LoginTokenTable interface {
}
type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}

View file

@ -79,7 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable,
devDB tables.DevicesTable,
localpart string,
serverName gomatrixserverlib.ServerName,
serverName gomatrixserverlib.ServerName, // nolint:unparam
accType api.AccountType,
userAgent string,
) {