mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-13 09:03:10 -06:00
addressed review comments
This commit is contained in:
parent
2fcc16fbb7
commit
5346ce735a
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
|
||||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
|
@ -39,8 +40,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
||||||
}
|
}
|
||||||
request := struct {
|
request := struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
UsesAllowed int32 `json:"uses_allowed"`
|
UsesAllowed *int32 `json:"uses_allowed,omitempty"`
|
||||||
ExpiryTime int64 `json:"expiry_time"`
|
ExpiryTime *int64 `json:"expiry_time,omitempty"`
|
||||||
Length int32 `json:"length"`
|
Length int32 `json:"length"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
|
|
@ -87,15 +88,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// At this point, we have a valid token, either through request body or through random generation.
|
// At this point, we have a valid token, either through request body or through random generation.
|
||||||
|
if usesAllowed != nil && *usesAllowed < 0 {
|
||||||
if usesAllowed < 0 {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
|
||||||
if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
JSON: spec.BadJSON("expiry_time must not be in the past"),
|
||||||
|
|
@ -106,10 +105,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
||||||
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
|
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
|
||||||
registrationToken := &clientapi.RegistrationToken{
|
registrationToken := &clientapi.RegistrationToken{
|
||||||
Token: &token,
|
Token: &token,
|
||||||
UsesAllowed: &usesAllowed,
|
UsesAllowed: usesAllowed,
|
||||||
Pending: &pending,
|
Pending: &pending,
|
||||||
Completed: &completed,
|
Completed: &completed,
|
||||||
ExpiryTime: &expiryTime,
|
ExpiryTime: expiryTime,
|
||||||
}
|
}
|
||||||
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
|
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
|
||||||
if !created {
|
if !created {
|
||||||
|
|
@ -130,19 +129,19 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: map[string]interface{}{
|
JSON: map[string]interface{}{
|
||||||
"token": token,
|
"token": token,
|
||||||
"uses_allowed": getReturnValueForUsesAllowed(usesAllowed),
|
"uses_allowed": getReturnValue(usesAllowed),
|
||||||
"pending": pending,
|
"pending": pending,
|
||||||
"completed": completed,
|
"completed": completed,
|
||||||
"expiry_time": getReturnValueExpiryTime(expiryTime),
|
"expiry_time": getReturnValue(expiryTime),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
|
func getReturnValue[t constraints.Integer](in *t) any {
|
||||||
if usesAllowed == 0 {
|
if in == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return usesAllowed
|
return *in
|
||||||
}
|
}
|
||||||
|
|
||||||
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
|
@ -176,13 +175,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getReturnValueExpiryTime(expiryTime int64) interface{} {
|
|
||||||
if expiryTime == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return expiryTime
|
|
||||||
}
|
|
||||||
|
|
||||||
func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
||||||
_, err := stmt.ExecContext(
|
_, err := stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
*registrationToken.Token,
|
*registrationToken.Token,
|
||||||
nullIfZero(*registrationToken.UsesAllowed),
|
getInsertValue(registrationToken.UsesAllowed),
|
||||||
nullIfZero(*registrationToken.ExpiryTime),
|
getInsertValue(registrationToken.ExpiryTime),
|
||||||
*registrationToken.Pending,
|
*registrationToken.Pending,
|
||||||
*registrationToken.Completed)
|
*registrationToken.Completed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nullIfZero[t constraints.Integer](in t) any {
|
func getInsertValue[t constraints.Integer](in *t) any {
|
||||||
if in == 0 {
|
if in == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return in
|
return *in
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||||
|
|
|
||||||
|
|
@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
||||||
_, err := stmt.ExecContext(
|
_, err := stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
*registrationToken.Token,
|
*registrationToken.Token,
|
||||||
nullIfZero(*registrationToken.UsesAllowed),
|
getInsertValue(registrationToken.UsesAllowed),
|
||||||
nullIfZero(*registrationToken.ExpiryTime),
|
getInsertValue(registrationToken.ExpiryTime),
|
||||||
*registrationToken.Pending,
|
*registrationToken.Pending,
|
||||||
*registrationToken.Completed)
|
*registrationToken.Completed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nullIfZero[t constraints.Integer](in t) any {
|
func getInsertValue[t constraints.Integer](in *t) any {
|
||||||
if in == 0 {
|
if in == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return in
|
return *in
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue