OpenID
This commit is contained in:
parent
62dd0afc0b
commit
be2f90f0b8
|
@ -3,6 +3,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertOpenIDTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectOpenIDTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||||
|
|
||||||
type openIDTokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
insertTokenStmt *sql.Stmt
|
insertTokenStmt *sql.Stmt
|
||||||
|
@ -55,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
expiresAtMS int64,
|
expiresAtMS int64,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
&openIDTokenAttrs.UserID,
|
&localpart, &serverName,
|
||||||
&openIDTokenAttrs.ExpiresAtMS,
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
)
|
)
|
||||||
|
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
|
|
@ -383,11 +383,15 @@ func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serv
|
||||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||||
func (d *Database) CreateOpenIDToken(
|
func (d *Database) CreateOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
token, localpart string,
|
token, userID string,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
|
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||||
})
|
})
|
||||||
return expiresAtMS, err
|
return expiresAtMS, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||||
CREATE TABLE userapi_accounts (
|
CREATE TABLE userapi_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
|
|
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||||
session_id INTEGER,
|
session_id INTEGER,
|
||||||
device_id TEXT ,
|
device_id TEXT ,
|
||||||
localpart TEXT ,
|
localpart TEXT ,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT,
|
created_ts BIGINT,
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
last_seen_ts BIGINT,
|
last_seen_ts BIGINT,
|
||||||
|
|
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||||
CREATE TABLE userapi_accounts (
|
CREATE TABLE userapi_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
|
|
|
@ -3,6 +3,7 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -25,10 +26,10 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertOpenIDTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectOpenIDTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||||
|
|
||||||
type openIDTokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -57,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
||||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
expiresAtMS int64,
|
expiresAtMS int64,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
_, err = stmt.ExecContext(ctx, token, serverName, localpart, expiresAtMS)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
&openIDTokenAttrs.UserID,
|
&localpart, &serverName,
|
||||||
&openIDTokenAttrs.ExpiresAtMS,
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
)
|
)
|
||||||
|
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
|
|
@ -79,7 +79,7 @@ type LoginTokenTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenIDTable interface {
|
type OpenIDTable interface {
|
||||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
|
||||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ func mustMakeAccountAndDevice(
|
||||||
accDB tables.AccountsTable,
|
accDB tables.AccountsTable,
|
||||||
devDB tables.DevicesTable,
|
devDB tables.DevicesTable,
|
||||||
localpart string,
|
localpart string,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName, // nolint:unparam
|
||||||
accType api.AccountType,
|
accType api.AccountType,
|
||||||
userAgent string,
|
userAgent string,
|
||||||
) {
|
) {
|
||||||
|
|
Loading…
Reference in a new issue