mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 16:21:55 -06:00
223 lines
7.3 KiB
Go
223 lines
7.3 KiB
Go
|
package postgres
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"time"
|
||
|
|
||
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||
|
internal "github.com/matrix-org/dendrite/internal"
|
||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||
|
"golang.org/x/exp/constraints"
|
||
|
)
|
||
|
|
||
|
// registrationTokensSchema creates the table backing registration tokens.
// uses_allowed and expiry_time are nullable: the list queries treat a NULL
// uses_allowed as "no usage limit" and a NULL expiry_time as "never expires".
// expiry_time is compared against the current time in Unix milliseconds.
const registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
	token TEXT PRIMARY KEY,
	pending BIGINT,
	completed BIGINT,
	uses_allowed BIGINT,
	expiry_time BIGINT
);
`
|
||
|
|
||
|
// SQL statements for the userapi_registration_tokens table.
//
// The list queries name their columns explicitly, in the order
// ListRegistrationTokens scans them, instead of relying on SELECT * and the
// table's physical column order. In the valid/invalid list queries, $1 is
// the current time in Unix milliseconds and is compared against expiry_time.
const (
	selectTokenSQL = "" +
		"SELECT token FROM userapi_registration_tokens WHERE token = $1"

	insertTokenSQL = "" +
		"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"

	// listAllTokensSQL returns every token regardless of validity.
	listAllTokensSQL = "" +
		"SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens"

	// listValidTokensSQL returns tokens that still have uses left (or have no
	// limit) and have not expired (or never expire).
	listValidTokensSQL = "" +
		"SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE " +
		"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND " +
		"(expiry_time > $1 OR expiry_time IS NULL)"

	// listInvalidTokensSQL returns tokens whose uses are exhausted or that
	// have expired.
	listInvalidTokensSQL = "" +
		"SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE " +
		"(uses_allowed <= pending + completed OR expiry_time <= $1)"

	getTokenSQL = "" +
		"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"

	deleteTokenSQL = "" +
		"DELETE FROM userapi_registration_tokens WHERE token = $1"

	updateTokenUsesAllowedAndExpiryTimeSQL = "" +
		"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"

	updateTokenUsesAllowedSQL = "" +
		"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"

	updateTokenExpiryTimeSQL = "" +
		"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
)
|
||
|
|
||
|
// registrationTokenStatements holds the prepared statements behind the
// Postgres implementation of tables.RegistrationTokensTable. All fields are
// populated by NewPostgresRegistrationTokensTable.
type registrationTokenStatements struct {
	// Existence check and creation of a single token.
	selectTokenStatement, insertTokenStatement *sql.Stmt
	// Listing all, currently valid, or currently invalid tokens.
	listAllTokensStatement, listValidTokensStatement, listInvalidTokenStatement *sql.Stmt
	// Fetching and deleting one token's row.
	getTokenStatement, deleteTokenStatement *sql.Stmt
	// Updating the usage limit and/or the expiry time.
	updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedStatement, updateTokenExpiryTimeStatement *sql.Stmt
}
|
||
|
|
||
|
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||
|
s := ®istrationTokenStatements{}
|
||
|
_, err := db.Exec(registrationTokensSchema)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return s, sqlutil.StatementList{
|
||
|
{&s.selectTokenStatement, selectTokenSQL},
|
||
|
{&s.insertTokenStatement, insertTokenSQL},
|
||
|
{&s.listAllTokensStatement, listAllTokensSQL},
|
||
|
{&s.listValidTokensStatement, listValidTokensSQL},
|
||
|
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||
|
{&s.getTokenStatement, getTokenSQL},
|
||
|
{&s.deleteTokenStatement, deleteTokenSQL},
|
||
|
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||
|
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||
|
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||
|
}.Prepare(db)
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||
|
var existingToken string
|
||
|
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||
|
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||
|
if err != nil {
|
||
|
if err == sql.ErrNoRows {
|
||
|
return false, nil
|
||
|
}
|
||
|
return false, err
|
||
|
}
|
||
|
return true, nil
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
|
||
|
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
|
||
|
_, err := stmt.ExecContext(
|
||
|
ctx,
|
||
|
*registrationToken.Token,
|
||
|
getInsertValue(registrationToken.UsesAllowed),
|
||
|
getInsertValue(registrationToken.ExpiryTime),
|
||
|
*registrationToken.Pending,
|
||
|
*registrationToken.Completed)
|
||
|
if err != nil {
|
||
|
return false, err
|
||
|
}
|
||
|
return true, nil
|
||
|
}
|
||
|
|
||
|
func getInsertValue[t constraints.Integer](in *t) any {
|
||
|
if in == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return *in
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||
|
var stmt *sql.Stmt
|
||
|
var tokens []api.RegistrationToken
|
||
|
var tokenString string
|
||
|
var pending, completed, usesAllowed *int32
|
||
|
var expiryTime *int64
|
||
|
var rows *sql.Rows
|
||
|
var err error
|
||
|
if returnAll {
|
||
|
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||
|
rows, err = stmt.QueryContext(ctx)
|
||
|
} else if valid {
|
||
|
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||
|
} else {
|
||
|
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||
|
}
|
||
|
if err != nil {
|
||
|
return tokens, err
|
||
|
}
|
||
|
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||
|
for rows.Next() {
|
||
|
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||
|
if err != nil {
|
||
|
return tokens, err
|
||
|
}
|
||
|
tokenString := tokenString
|
||
|
pending := pending
|
||
|
completed := completed
|
||
|
usesAllowed := usesAllowed
|
||
|
expiryTime := expiryTime
|
||
|
|
||
|
tokenMap := api.RegistrationToken{
|
||
|
Token: &tokenString,
|
||
|
Pending: pending,
|
||
|
Completed: completed,
|
||
|
UsesAllowed: usesAllowed,
|
||
|
ExpiryTime: expiryTime,
|
||
|
}
|
||
|
tokens = append(tokens, tokenMap)
|
||
|
}
|
||
|
return tokens, rows.Err()
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||
|
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||
|
var pending, completed, usesAllowed *int32
|
||
|
var expiryTime *int64
|
||
|
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
token := api.RegistrationToken{
|
||
|
Token: &tokenString,
|
||
|
Pending: pending,
|
||
|
Completed: completed,
|
||
|
UsesAllowed: usesAllowed,
|
||
|
ExpiryTime: expiryTime,
|
||
|
}
|
||
|
return &token, nil
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||
|
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||
|
_, err := stmt.ExecContext(ctx, tokenString)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||
|
var stmt *sql.Stmt
|
||
|
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||
|
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||
|
if usesAllowedPresent && expiryTimePresent {
|
||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
} else if usesAllowedPresent {
|
||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
} else if expiryTimePresent {
|
||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||
|
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||
|
}
|