dendrite/userapi/storage/accounts/sqlite3/openid_table.go
2020-11-17 11:17:31 -05:00

100 lines
2.8 KiB
Go

package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// openIDTokenSchema creates the account_openid table, which maps OpenID token
// values to the localpart of the local account they were issued for.
//
// Note: the column list must not end with a trailing comma; SQLite rejects
// that as a syntax error, which would make prepare fail on every start-up.
const openIDTokenSchema = `
-- Stores OpenID tokens issued to local accounts.
CREATE TABLE IF NOT EXISTS account_openid (
	-- The token value itself; the primary key of this table.
	token TEXT NOT NULL PRIMARY KEY,
	-- The Matrix user ID localpart for this account
	localpart TEXT NOT NULL,
	-- When this token was first created, as a unix timestamp (ms resolution).
	token_created_ts BIGINT NOT NULL,
	-- When the token expires, as a unix timestamp (ms resolution).
	token_expires_ts BIGINT NOT NULL,
	-- (optional) Relying Party the token was created for; NULL when unset
	token_rp TEXT
);
`
// SQL statements for the account_openid table. Placeholders are positional
// and follow the order of each statement's column list.
const (
	// insertTokenSQL stores one token row; token_rp may be bound to NULL.
	insertTokenSQL = "INSERT INTO account_openid(token, localpart, token_created_ts, token_expires_ts, token_rp) VALUES ($1, $2, $3, $4, $5)"

	// selectTokenSQL fetches the row whose primary key equals $1.
	selectTokenSQL = "SELECT token, localpart, token_created_ts, token_expires_ts, token_rp FROM account_openid WHERE token = $1"
)
// tokenStatements holds the database handle and prepared statements for the
// account_openid table. It is populated by prepare.
type tokenStatements struct {
	// db is the handle passed to prepare.
	db *sql.DB
	// insertTokenStmt is the prepared form of insertTokenSQL.
	insertTokenStmt *sql.Stmt
	// selectTokenStmt is the prepared form of selectTokenSQL.
	selectTokenStmt *sql.Stmt
	// serverName is combined with stored localparts to build full user IDs.
	serverName gomatrixserverlib.ServerName
}
// prepare creates the account_openid table if it does not exist and prepares
// every statement this type uses. db is always recorded on s; server is only
// recorded once all statements have been prepared successfully. The first
// error encountered is returned as-is.
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
	s.db = db
	if _, err = db.Exec(openIDTokenSchema); err != nil {
		return err
	}

	// Each entry pairs a destination field with the SQL it should hold.
	toPrepare := []struct {
		target **sql.Stmt
		query  string
	}{
		{&s.insertTokenStmt, insertTokenSQL},
		{&s.selectTokenStmt, selectTokenSQL},
	}
	for _, p := range toPrepare {
		if *p.target, err = db.Prepare(p.query); err != nil {
			return err
		}
	}

	s.serverName = server
	return nil
}
// insertToken writes a single account_openid row for localpart using the
// statement obtained from sqlutil.TxStmt(txn, ...), so it participates in txn
// when one is supplied. An empty tokenRP is stored as SQL NULL rather than as
// an empty string. Only an error is returned; a duplicate token value
// violates the table's primary key and surfaces as that error.
func (s *tokenStatements) insertToken(
	ctx context.Context,
	txn *sql.Tx,
	token, localpart string,
	createdTimeMS, expiresTimeMS int64,
	tokenRP string,
) (err error) {
	// A nil interface value is bound as NULL ("no relying party").
	var relyingParty interface{}
	if tokenRP != "" {
		relyingParty = tokenRP
	}
	_, err = sqlutil.TxStmt(txn, s.insertTokenStmt).ExecContext(
		ctx, token, localpart, createdTimeMS, expiresTimeMS, relyingParty,
	)
	return err
}
// selectToken looks up token in account_openid and returns it with UserID
// built from the stored localpart and s.serverName.
//
// If no row matches, the error from Scan (sql.ErrNoRows) is returned unchanged
// so callers can tell "unknown token" apart from other failures. A NULL
// token_rp, which insertToken stores when no relying party was given, is
// returned as an empty RelyingParty.
func (s *tokenStatements) selectToken(
	ctx context.Context,
	token string,
) (*api.OpenIDToken, error) {
	var (
		// Scan into a real value; taking field addresses of a nil
		// *api.OpenIDToken would panic.
		openIDToken api.OpenIDToken
		localpart   string
		// token_rp is nullable, so scan through NullString.
		relyingParty sql.NullString
	)
	err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
		&openIDToken.Token,
		&localpart,
		&openIDToken.CreatedTS,
		&openIDToken.ExpiresTS,
		&relyingParty,
	)
	if err != nil {
		return nil, err
	}
	// NOTE(review): assumes api.OpenIDToken.RelyingParty is a string,
	// mirroring insertToken's tokenRP parameter — confirm.
	openIDToken.RelyingParty = relyingParty.String
	openIDToken.UserID = userutil.MakeUserID(localpart, s.serverName)
	return &openIDToken, nil
}