addressed review comments

This commit is contained in:
santhoshivan23 2023-06-22 00:02:09 +05:30
parent 2fcc16fbb7
commit 5346ce735a
3 changed files with 22 additions and 30 deletions

View file

@ -18,6 +18,7 @@ import (
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/exp/constraints"
clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
@ -39,8 +40,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
request := struct {
Token string `json:"token"`
UsesAllowed int32 `json:"uses_allowed"`
ExpiryTime int64 `json:"expiry_time"`
UsesAllowed *int32 `json:"uses_allowed,omitempty"`
ExpiryTime *int64 `json:"expiry_time,omitempty"`
Length int32 `json:"length"`
}{}
@ -87,15 +88,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
}
// At this point, we have a valid token, either through request body or through random generation.
if usesAllowed < 0 {
if usesAllowed != nil && *usesAllowed < 0 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}
if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) {
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
@ -106,10 +105,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
registrationToken := &clientapi.RegistrationToken{
Token: &token,
UsesAllowed: &usesAllowed,
UsesAllowed: usesAllowed,
Pending: &pending,
Completed: &completed,
ExpiryTime: &expiryTime,
ExpiryTime: expiryTime,
}
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
if !created {
@ -130,19 +129,19 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
Code: 200,
JSON: map[string]interface{}{
"token": token,
"uses_allowed": getReturnValueForUsesAllowed(usesAllowed),
"uses_allowed": getReturnValue(usesAllowed),
"pending": pending,
"completed": completed,
"expiry_time": getReturnValueExpiryTime(expiryTime),
"expiry_time": getReturnValue(expiryTime),
},
}
}
func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
if usesAllowed == 0 {
func getReturnValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return usesAllowed
return *in
}
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
@ -176,13 +175,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
}
}
func getReturnValueExpiryTime(expiryTime int64) interface{} {
if expiryTime == 0 {
return nil
}
return expiryTime
}
func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {

View file

@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}
func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return in
return *in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {

View file

@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}
func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return in
return *in
}
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {