implement filter by valid query param

This commit is contained in:
santhoshivan23 2023-06-08 10:03:34 +05:30
parent 356edeb6b4
commit 86d2aa41c1
2 changed files with 39 additions and 19 deletions

View file

@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
}
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MatrixErrorResponse(
http.StatusInternalServerError,
string(spec.ErrorInvalidParam),
"unable to parse query params",
)
}
queryParams := req.URL.Query()
returnAll := true
validQuery, ok := vars["valid"]
valid := true
validQuery, ok := queryParams["valid"]
if ok {
returnAll = false
validValue, err := strconv.ParseBool(validQuery[0])
if err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"invalid 'valid' query parameter",
)
}
valid = validValue
}
valid, err := strconv.ParseBool(validQuery)
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
if err != nil {
return util.MatrixErrorResponse(

View file

@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -25,13 +26,24 @@ const selectTokenSQL = "" +
const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
const listTokensSQL = "" +
const listAllTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens"
const listValidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
"(expiry_time > $1 OR expiry_time IS NULL)"
const listInvalidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listTokensStatement *sql.Stmt
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
}
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatement, insertTokenSQL},
{&s.listTokensStatement, listTokensSQL},
{&s.listAllTokensStatement, listAllTokensSQL},
{&s.listValidTokensStatement, listValidTokensSQL},
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
}.Prepare(db)
}
@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context
var tokenString sql.NullString
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
var rows *sql.Rows
var err error
if returnAll {
stmt = s.listTokensStatement
stmt = s.listAllTokensStatement
rows, err = stmt.QueryContext(ctx)
} else if valid {
// TODO: Statement to Get All Valid Tokens
stmt = s.listValidTokensStatement
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else {
// TODO: Statement to Get All Invalid Tokens
stmt = s.listInvalidTokenStatement
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
}
rows, err := stmt.QueryContext(ctx)
if err != nil {
return tokens, err
}