mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-12 16:43:09 -06:00
81 lines
2 KiB
Go
81 lines
2 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
|
)
|
|
|
|
// SQL used by the Postgres registration tokens table.
const (
	// registrationTokensSchema creates the userapi_registration_tokens table
	// if it does not already exist. Only token is the primary key; the other
	// columns carry no NOT NULL constraint and may therefore hold NULL.
	registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
	token TEXT PRIMARY KEY,
	pending BIGINT,
	completed BIGINT,
	uses_allowed BIGINT,
	expiry_time BIGINT
);
`

	// selectTokenSQL looks up a token by its exact value; it yields a row only
	// when the token is present. Parameter: $1 = token.
	selectTokenSQL = "SELECT token FROM userapi_registration_tokens WHERE token = $1"

	// insertTokenSQL stores a new token. Parameter order:
	// $1 token, $2 uses_allowed, $3 expiry_time, $4 pending, $5 completed.
	insertTokenSQL = "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
)
|
|
|
|
// registrationTokenStatements holds the prepared statements used to read and
// write the userapi_registration_tokens table. Both fields are filled in by
// NewPostgresRegistrationTokensTable.
type registrationTokenStatements struct {
	// selectTokenStatement is prepared from selectTokenSQL.
	selectTokenStatement *sql.Stmt
	// insertTokenStatment is prepared from insertTokenSQL. The field name's
	// spelling is kept as-is because other code in this file refers to it.
	insertTokenStatment *sql.Stmt
}
|
|
|
|
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
|
s := ®istrationTokenStatements{}
|
|
_, err := db.Exec(registrationTokensSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s, sqlutil.StatementList{
|
|
{&s.selectTokenStatement, selectTokenSQL},
|
|
{&s.insertTokenStatment, insertTokenSQL},
|
|
}.Prepare(db)
|
|
}
|
|
|
|
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
|
var existingToken string
|
|
stmt := s.selectTokenStatement
|
|
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) {
|
|
stmt := sqlutil.TxStmt(tx, s.insertTokenStatment)
|
|
pending := 0
|
|
completed := 0
|
|
_, err := stmt.ExecContext(ctx, token, nullIfZeroInt32(usesAllowed), nullIfZero(expiryTime), pending, completed)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// nullIfZero returns nil when value is 0 and value itself (as an int64)
// otherwise. The nil is meant to be passed to database/sql, which binds it as
// SQL NULL.
func nullIfZero(value int64) interface{} {
	if value != 0 {
		return value
	}
	return nil
}
|
|
|
|
// nullIfZeroInt32 is the int32 counterpart of nullIfZero. It returns nil when
// value is 0, so database/sql binds it as SQL NULL, and the unchanged int32
// otherwise.
func nullIfZeroInt32(value int32) interface{} {
	if value != 0 {
		return value
	}
	return nil
}
|