implement filter by valid query param

This commit is contained in:
santhoshivan23 2023-06-08 10:03:34 +05:30
parent 356edeb6b4
commit 86d2aa41c1
2 changed files with 39 additions and 19 deletions

View file

@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
} }
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) queryParams := req.URL.Query()
if err != nil {
return util.MatrixErrorResponse(
http.StatusInternalServerError,
string(spec.ErrorInvalidParam),
"unable to parse query params",
)
}
returnAll := true returnAll := true
validQuery, ok := vars["valid"] valid := true
validQuery, ok := queryParams["valid"]
if ok { if ok {
returnAll = false returnAll = false
validValue, err := strconv.ParseBool(validQuery[0])
if err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"invalid 'valid' query parameter",
)
}
valid = validValue
} }
valid, err := strconv.ParseBool(validQuery)
tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
if err != nil { if err != nil {
return util.MatrixErrorResponse( return util.MatrixErrorResponse(

View file

@ -3,6 +3,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -25,13 +26,24 @@ const selectTokenSQL = "" +
const insertTokenSQL = "" + const insertTokenSQL = "" +
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
const listTokensSQL = "" + const listAllTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens" "SELECT * FROM userapi_registration_tokens"
const listValidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
"(expiry_time > $1 OR expiry_time IS NULL)"
const listInvalidTokensSQL = "" +
"SELECT * FROM userapi_registration_tokens WHERE" +
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
type registrationTokenStatements struct { type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt insertTokenStatement *sql.Stmt
listTokensStatement *sql.Stmt listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
} }
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL}, {&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatement, insertTokenSQL}, {&s.insertTokenStatement, insertTokenSQL},
{&s.listTokensStatement, listTokensSQL}, {&s.listAllTokensStatement, listAllTokensSQL},
{&s.listValidTokensStatement, listValidTokensSQL},
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context
var tokenString sql.NullString var tokenString sql.NullString
var pending, completed, usesAllowed sql.NullInt32 var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64 var expiryTime sql.NullInt64
var rows *sql.Rows
var err error
if returnAll { if returnAll {
stmt = s.listTokensStatement stmt = s.listAllTokensStatement
rows, err = stmt.QueryContext(ctx)
} else if valid { } else if valid {
// TODO: Statement to Get All Valid Tokens stmt = s.listValidTokensStatement
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} else { } else {
// TODO: Statement to Get All Invalid Tokens stmt = s.listInvalidTokenStatement
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
} }
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return tokens, err return tokens, err
} }