diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index d0608f7aa..96d557ae4 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { } func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.MatrixErrorResponse( - http.StatusInternalServerError, - string(spec.ErrorInvalidParam), - "unable to parse query params", - ) - } + queryParams := req.URL.Query() returnAll := true - validQuery, ok := vars["valid"] + valid := true + validQuery, ok := queryParams["valid"] if ok { returnAll = false + validValue, err := strconv.ParseBool(validQuery[0]) + if err != nil { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "invalid 'valid' query parameter", + ) + } + valid = validValue } - valid, err := strconv.ParseBool(validQuery) tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) if err != nil { return util.MatrixErrorResponse( diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 666fb3c3e..3bb8b19e8 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -25,13 +26,24 @@ const selectTokenSQL = "" + const insertTokenSQL = "" + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" -const listTokensSQL = "" + +const listAllTokensSQL = "" + "SELECT * FROM userapi_registration_tokens" +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + type registrationTokenStatements struct { - selectTokenStatement *sql.Stmt - insertTokenStatement *sql.Stmt - listTokensStatement *sql.Stmt + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa return s, sqlutil.StatementList{ {&s.selectTokenStatement, selectTokenSQL}, {&s.insertTokenStatement, insertTokenSQL}, - {&s.listTokensStatement, listTokensSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, }.Prepare(db) } @@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context var tokenString sql.NullString var pending, completed, usesAllowed sql.NullInt32 var expiryTime sql.NullInt64 + var rows *sql.Rows + var err error if returnAll { - stmt = s.listTokensStatement + stmt = s.listAllTokensStatement + rows, err = stmt.QueryContext(ctx) } else if valid { - // TODO: Statement to Get All Valid Tokens + stmt = s.listValidTokensStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - // TODO: Statement to Get All Invalid Tokens + stmt = s.listInvalidTokenStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } - rows, err := stmt.QueryContext(ctx) if err != nil { return tokens, err }