implement update api

This commit is contained in:
santhoshivan23 2023-06-11 00:51:51 +05:30
parent 31f3125c26
commit 5e0da6ac0e
8 changed files with 130 additions and 10 deletions

View file

@ -182,7 +182,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
"error fetching registration tokens",
)
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
@ -238,6 +237,67 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
}
}
func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
tokenText := vars["token"]
request := make(map[string]interface{})
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorBadJSON),
"Failed to decode request body:",
)
}
newAttributes := make(map[string]interface{})
usesAllowed, ok := request["uses_allowed"]
if ok {
// Only add usesAllowed to newAtrributes if it is present and valid
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
// Synapse's behaviour of updating the field if and only if it is present in request body.
if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"uses_allowed must be a non-negative integer or null",
)
}
newAttributes["usesAllowed"] = usesAllowed
}
expiryTime, ok := request["expiry_time"]
if ok {
// Only add expiryTime to newAtrributes if it is present and valid
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
// Synapse's behaviour of updating the field if and only if it is present in request body.
if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) {
return util.MatrixErrorResponse(
http.StatusBadRequest,
string(spec.ErrorInvalidParam),
"expiry_time must be in the future",
)
}
newAttributes["expiryTime"] = expiryTime
}
if len(newAttributes) == 0 {
// No attributes to update. Return existing token
return AdminGetRegistrationToken(req, cfg, userAPI)
}
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
if err != nil {
return util.MatrixErrorResponse(
http.StatusNotFound,
string(spec.ErrorUnknown),
fmt.Sprintf("token: %s not found", tokenText),
)
}
return util.JSONResponse{
Code: 200,
JSON: *updatedToken,
}
}
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {

View file

@ -179,7 +179,7 @@ func Setup(
if req.Method == http.MethodGet {
return AdminGetRegistrationToken(req, cfg, userAPI)
} else if req.Method == http.MethodPut {
return AdminUpdateRegistrationToken(req, cfg, userAPI)
} else if req.Method == http.MethodDelete {
return AdminDeleteRegistrationToken(req, cfg, userAPI)
}

View file

@ -99,6 +99,7 @@ type ClientUserAPI interface {
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error

View file

@ -99,6 +99,14 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex
return a.DB.DeleteRegistrationToken(ctx, tokenString)
}
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
if err != nil {
return nil, err
}
return token, nil
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {

View file

@ -37,6 +37,7 @@ type RegistrationTokens interface {
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
DeleteRegistrationToken(ctx context.Context, tokenString string) error
UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
}
type Profile interface {

View file

@ -45,14 +45,26 @@ const getTokenSQL = "" +
const deleteTokenSQL = "" +
"DELETE FROM userapi_registration_tokens WHERE token = $1"
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
const updateTokenUsesAllowedSQL = "" +
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
const updateTokenExpiryTimeSQL = "" +
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
getTokenStatement *sql.Stmt
deleteTokenStatement *sql.Stmt
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
getTokenStatement *sql.Stmt
deleteTokenStatement *sql.Stmt
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
updateTokenUsesAllowedStatement *sql.Stmt
updateTokenExpiryTimeStatement *sql.Stmt
}
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
@ -69,6 +81,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
{&s.getTokenStatement, getTokenSQL},
{&s.deleteTokenStatement, deleteTokenSQL},
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
}.Prepare(db)
}
@ -175,7 +190,7 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 {
}
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
stmt := s.getTokenStatement
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
var pending, completed, usesAllowed sql.NullInt32
var expiryTime sql.NullInt64
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
@ -207,3 +222,29 @@ func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Contex
}
return nil
}
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
var stmt *sql.Stmt
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
if usesAllowedPresent && expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
if err != nil {
return nil, err
}
} else if usesAllowedPresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
if err != nil {
return nil, err
}
} else if expiryTimePresent {
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
if err != nil {
return nil, err
}
}
return s.GetRegistrationToken(ctx, tx, tokenString)
}

View file

@ -104,6 +104,14 @@ func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString stri
return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
}
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
return err
})
return
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(

View file

@ -36,6 +36,7 @@ type RegistrationTokensTable interface {
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
}
type AccountDataTable interface {