From 99fa964b62f4fdeee81dfdf24c5cf6e4b9f67d46 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 22:21:21 +0530 Subject: [PATCH] handle cases when request field is not present --- clientapi/routing/admin.go | 40 +++++++++++++------ .../postgres/registration_tokens_table.go | 16 +++++++- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8accef535..f2a391c9c 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -28,17 +28,6 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" ) -func generateRandomToken(length int) string { - allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" - rand.Seed(time.Now().UnixNano()) - var sb strings.Builder - for i := 0; i < length; i++ { - randomIndex := rand.Intn(len(allowedChars)) - sb.WriteByte(allowedChars[randomIndex]) - } - return sb.String() -} - func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( @@ -133,14 +122,39 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u Code: 200, JSON: map[string]interface{}{ "token": token, - "uses_allowed": usesAllowed, + "uses_allowed": getReturnValueForUsesAllowed(usesAllowed), "pending": pending, "completed": completed, - "expiry_time": expiryTime, + "expiry_time": getReturnValueExpiryTime(expiryTime), }, } } +func generateRandomToken(length int) string { + allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + rand.Seed(time.Now().UnixNano()) + var sb strings.Builder + for i := 0; i < length; i++ { + randomIndex := rand.Intn(len(allowedChars)) + sb.WriteByte(allowedChars[randomIndex]) + } + return sb.String() +} + +func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { + if usesAllowed == 0 { + return nil + } + return usesAllowed +} + +func getReturnValueExpiryTime(expiryTime int64) interface{} { + if expiryTime == 0 { + return nil + } + return expiryTime +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 750e53b26..6c55444c0 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -58,9 +58,23 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex stmt := sqlutil.TxStmt(tx, s.insertTokenStatment) pending := 0 completed := 0 - _, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed) + _, err := stmt.ExecContext(ctx, token, nullIfZeroInt32(usesAllowed), nullIfZero(expiryTime), pending, completed) if err != nil { return false, err } return true, nil } + +func nullIfZero(value int64) interface{} { + if value == 0 { + return nil + } + return value +} + +func nullIfZeroInt32(value int32) interface{} { + if value == 0 { + return nil + } + return value +}