OpenID
This commit is contained in:
parent
62dd0afc0b
commit
be2f90f0b8
|
@ -3,6 +3,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
|||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
|
@ -55,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -70,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
|
|
@ -383,11 +383,15 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serv
|
|||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
token, localpart string,
|
||||
token, userID string,
|
||||
) (int64, error) {
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
|||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
|||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
|
|
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
|||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
|
@ -3,6 +3,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
|||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -57,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
|||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -72,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
|
|
@ -79,7 +79,7 @@ type LoginTokenTable interface {
|
|||
}
|
||||
|
||||
type OpenIDTable interface {
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
|
||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ func mustMakeAccountAndDevice(
|
|||
accDB tables.AccountsTable,
|
||||
devDB tables.DevicesTable,
|
||||
localpart string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
serverName gomatrixserverlib.ServerName, // nolint:unparam
|
||||
accType api.AccountType,
|
||||
userAgent string,
|
||||
) {
|
||||
|
|
Loading…
Reference in a new issue