mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 08:03:09 -06:00
implement filter by valid query param
This commit is contained in:
parent
356edeb6b4
commit
86d2aa41c1
|
|
@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
queryParams := req.URL.Query()
|
||||||
if err != nil {
|
|
||||||
return util.MatrixErrorResponse(
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
string(spec.ErrorInvalidParam),
|
|
||||||
"unable to parse query params",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
returnAll := true
|
returnAll := true
|
||||||
validQuery, ok := vars["valid"]
|
valid := true
|
||||||
|
validQuery, ok := queryParams["valid"]
|
||||||
if ok {
|
if ok {
|
||||||
returnAll = false
|
returnAll = false
|
||||||
|
validValue, err := strconv.ParseBool(validQuery[0])
|
||||||
|
if err != nil {
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
string(spec.ErrorInvalidParam),
|
||||||
|
"invalid 'valid' query parameter",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
valid = validValue
|
||||||
}
|
}
|
||||||
valid, err := strconv.ParseBool(validQuery)
|
|
||||||
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
|
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.MatrixErrorResponse(
|
return util.MatrixErrorResponse(
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/api"
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
@ -25,13 +26,24 @@ const selectTokenSQL = "" +
|
||||||
const insertTokenSQL = "" +
|
const insertTokenSQL = "" +
|
||||||
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
const listTokensSQL = "" +
|
const listAllTokensSQL = "" +
|
||||||
"SELECT * FROM userapi_registration_tokens"
|
"SELECT * FROM userapi_registration_tokens"
|
||||||
|
|
||||||
|
const listValidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||||
|
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||||
|
|
||||||
|
const listInvalidTokensSQL = "" +
|
||||||
|
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||||
|
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||||
|
|
||||||
type registrationTokenStatements struct {
|
type registrationTokenStatements struct {
|
||||||
selectTokenStatement *sql.Stmt
|
selectTokenStatement *sql.Stmt
|
||||||
insertTokenStatement *sql.Stmt
|
insertTokenStatement *sql.Stmt
|
||||||
listTokensStatement *sql.Stmt
|
listAllTokensStatement *sql.Stmt
|
||||||
|
listValidTokensStatement *sql.Stmt
|
||||||
|
listInvalidTokenStatement *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||||
|
|
@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.selectTokenStatement, selectTokenSQL},
|
{&s.selectTokenStatement, selectTokenSQL},
|
||||||
{&s.insertTokenStatement, insertTokenSQL},
|
{&s.insertTokenStatement, insertTokenSQL},
|
||||||
{&s.listTokensStatement, listTokensSQL},
|
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||||
|
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||||
|
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context
|
||||||
var tokenString sql.NullString
|
var tokenString sql.NullString
|
||||||
var pending, completed, usesAllowed sql.NullInt32
|
var pending, completed, usesAllowed sql.NullInt32
|
||||||
var expiryTime sql.NullInt64
|
var expiryTime sql.NullInt64
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
if returnAll {
|
if returnAll {
|
||||||
stmt = s.listTokensStatement
|
stmt = s.listAllTokensStatement
|
||||||
|
rows, err = stmt.QueryContext(ctx)
|
||||||
} else if valid {
|
} else if valid {
|
||||||
// TODO: Statement to Get All Valid Tokens
|
stmt = s.listValidTokensStatement
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
} else {
|
} else {
|
||||||
// TODO: Statement to Get All Invalid Tokens
|
stmt = s.listInvalidTokenStatement
|
||||||
|
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||||
}
|
}
|
||||||
rows, err := stmt.QueryContext(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tokens, err
|
return tokens, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue