mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-10 23:53:09 -06:00
implement filter by valid query param
This commit is contained in:
parent
356edeb6b4
commit
86d2aa41c1
|
|
@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
|
|||
}
|
||||
|
||||
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusInternalServerError,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"unable to parse query params",
|
||||
)
|
||||
}
|
||||
queryParams := req.URL.Query()
|
||||
returnAll := true
|
||||
validQuery, ok := vars["valid"]
|
||||
valid := true
|
||||
validQuery, ok := queryParams["valid"]
|
||||
if ok {
|
||||
returnAll = false
|
||||
validValue, err := strconv.ParseBool(validQuery[0])
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
http.StatusBadRequest,
|
||||
string(spec.ErrorInvalidParam),
|
||||
"invalid 'valid' query parameter",
|
||||
)
|
||||
}
|
||||
valid = validValue
|
||||
}
|
||||
valid, err := strconv.ParseBool(validQuery)
|
||||
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
|
@ -25,13 +26,24 @@ const selectTokenSQL = "" +
|
|||
const insertTokenSQL = "" +
|
||||
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const listTokensSQL = "" +
|
||||
const listAllTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens"
|
||||
|
||||
const listValidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||
|
||||
const listInvalidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||
|
||||
type registrationTokenStatements struct {
|
||||
selectTokenStatement *sql.Stmt
|
||||
insertTokenStatement *sql.Stmt
|
||||
listTokensStatement *sql.Stmt
|
||||
selectTokenStatement *sql.Stmt
|
||||
insertTokenStatement *sql.Stmt
|
||||
listAllTokensStatement *sql.Stmt
|
||||
listValidTokensStatement *sql.Stmt
|
||||
listInvalidTokenStatement *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||
|
|
@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
|
|||
return s, sqlutil.StatementList{
|
||||
{&s.selectTokenStatement, selectTokenSQL},
|
||||
{&s.insertTokenStatement, insertTokenSQL},
|
||||
{&s.listTokensStatement, listTokensSQL},
|
||||
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
|
@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context
|
|||
var tokenString sql.NullString
|
||||
var pending, completed, usesAllowed sql.NullInt32
|
||||
var expiryTime sql.NullInt64
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if returnAll {
|
||||
stmt = s.listTokensStatement
|
||||
stmt = s.listAllTokensStatement
|
||||
rows, err = stmt.QueryContext(ctx)
|
||||
} else if valid {
|
||||
// TODO: Statement to Get All Valid Tokens
|
||||
stmt = s.listValidTokensStatement
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
} else {
|
||||
// TODO: Statement to Get All Invalid Tokens
|
||||
stmt = s.listInvalidTokenStatement
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue