From 5e0da6ac0ed63ea1746e71b9d785e88ed47f2196 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sun, 11 Jun 2023 00:51:51 +0530 Subject: [PATCH] implement update api --- clientapi/routing/admin.go | 62 ++++++++++++++++++- clientapi/routing/routing.go | 2 +- userapi/api/api.go | 1 + userapi/internal/user_api.go | 8 +++ userapi/storage/interface.go | 1 + .../postgres/registration_tokens_table.go | 57 ++++++++++++++--- userapi/storage/shared/storage.go | 8 +++ userapi/storage/tables/interface.go | 1 + 8 files changed, 130 insertions(+), 10 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 1644a3009..53b1be3cb 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -182,7 +182,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA "error fetching registration tokens", ) } - return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -238,6 +237,67 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } } +func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + request := make(map[string]interface{}) + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorBadJSON), + "Failed to decode request body:", + ) + } + newAttributes := make(map[string]interface{}) + usesAllowed, ok := request["uses_allowed"] + if ok { + // Only add usesAllowed to newAtrributes if it is present and valid + // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic + // Synapse's behaviour of updating the field if and only if it is present in request body. + if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "uses_allowed must be a non-negative integer or null", + ) + } + newAttributes["usesAllowed"] = usesAllowed + } + expiryTime, ok := request["expiry_time"] + if ok { + // Only add expiryTime to newAtrributes if it is present and valid + // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic + // Synapse's behaviour of updating the field if and only if it is present in request body. + if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "expiry_time must be in the future", + ) + } + newAttributes["expiryTime"] = expiryTime + } + if len(newAttributes) == 0 { + // No attributes to update. Return existing token + return AdminGetRegistrationToken(req, cfg, userAPI) + } + updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: *updatedToken, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e1620c53c..79e628f3f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -179,7 +179,7 @@ func Setup( if req.Method == http.MethodGet { return AdminGetRegistrationToken(req, cfg, userAPI) } else if req.Method == http.MethodPut { - + return AdminUpdateRegistrationToken(req, cfg, userAPI) } else if req.Method == http.MethodDelete { return AdminDeleteRegistrationToken(req, cfg, userAPI) } diff --git a/userapi/api/api.go b/userapi/api/api.go index 532422d84..a0dce9758 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -99,6 +99,7 @@ type ClientUserAPI interface { PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error + PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 07f425097..2cfd649a8 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -99,6 +99,14 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex return a.DB.DeleteRegistrationToken(ctx, tokenString) } +func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { + token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) + if err != nil { + return nil, err + } + return token, nil +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 144cd1f73..125b31585 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -37,6 +37,7 @@ type RegistrationTokens interface { ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) DeleteRegistrationToken(ctx context.Context, tokenString string) error + UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) } type Profile interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 0163a2924..3f85f2093 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -45,14 +45,26 @@ const getTokenSQL = "" + const deleteTokenSQL = "" + "DELETE FROM userapi_registration_tokens WHERE token = $1" +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + type registrationTokenStatements struct { - selectTokenStatement *sql.Stmt - insertTokenStatement *sql.Stmt - listAllTokensStatement *sql.Stmt - listValidTokensStatement *sql.Stmt - listInvalidTokenStatement *sql.Stmt - getTokenStatement *sql.Stmt - deleteTokenStatement *sql.Stmt + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -69,6 +81,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa {&s.listInvalidTokenStatement, listInvalidTokensSQL}, {&s.getTokenStatement, getTokenSQL}, {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, }.Prepare(db) } @@ -175,7 +190,7 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 { } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { - stmt := s.getTokenStatement + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) var pending, completed, usesAllowed sql.NullInt32 var expiryTime sql.NullInt64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) @@ -207,3 +222,29 @@ func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Contex } return nil } + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 86f1fd9e6..481256db1 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -104,6 +104,14 @@ func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString stri return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) } +func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index fe902481a..3a0be73e4 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -36,6 +36,7 @@ type RegistrationTokensTable interface { ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error + UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) } type AccountDataTable interface {