mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 08:03:09 -06:00
implement update api
This commit is contained in:
parent
31f3125c26
commit
5e0da6ac0e
|
|
@ -182,7 +182,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
|
||||||
"error fetching registration tokens",
|
"error fetching registration tokens",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: map[string]interface{}{
|
JSON: map[string]interface{}{
|
||||||
|
|
@ -238,6 +237,67 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
tokenText := vars["token"]
|
||||||
|
request := make(map[string]interface{})
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
string(spec.ErrorBadJSON),
|
||||||
|
"Failed to decode request body:",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
newAttributes := make(map[string]interface{})
|
||||||
|
usesAllowed, ok := request["uses_allowed"]
|
||||||
|
if ok {
|
||||||
|
// Only add usesAllowed to newAtrributes if it is present and valid
|
||||||
|
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
|
||||||
|
// Synapse's behaviour of updating the field if and only if it is present in request body.
|
||||||
|
if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) {
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
string(spec.ErrorInvalidParam),
|
||||||
|
"uses_allowed must be a non-negative integer or null",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
newAttributes["usesAllowed"] = usesAllowed
|
||||||
|
}
|
||||||
|
expiryTime, ok := request["expiry_time"]
|
||||||
|
if ok {
|
||||||
|
// Only add expiryTime to newAtrributes if it is present and valid
|
||||||
|
// Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic
|
||||||
|
// Synapse's behaviour of updating the field if and only if it is present in request body.
|
||||||
|
if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) {
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
string(spec.ErrorInvalidParam),
|
||||||
|
"expiry_time must be in the future",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
newAttributes["expiryTime"] = expiryTime
|
||||||
|
}
|
||||||
|
if len(newAttributes) == 0 {
|
||||||
|
// No attributes to update. Return existing token
|
||||||
|
return AdminGetRegistrationToken(req, cfg, userAPI)
|
||||||
|
}
|
||||||
|
updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
|
||||||
|
if err != nil {
|
||||||
|
return util.MatrixErrorResponse(
|
||||||
|
http.StatusNotFound,
|
||||||
|
string(spec.ErrorUnknown),
|
||||||
|
fmt.Sprintf("token: %s not found", tokenText),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: *updatedToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
|
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,7 @@ func Setup(
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
return AdminGetRegistrationToken(req, cfg, userAPI)
|
return AdminGetRegistrationToken(req, cfg, userAPI)
|
||||||
} else if req.Method == http.MethodPut {
|
} else if req.Method == http.MethodPut {
|
||||||
|
return AdminUpdateRegistrationToken(req, cfg, userAPI)
|
||||||
} else if req.Method == http.MethodDelete {
|
} else if req.Method == http.MethodDelete {
|
||||||
return AdminDeleteRegistrationToken(req, cfg, userAPI)
|
return AdminDeleteRegistrationToken(req, cfg, userAPI)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,7 @@ type ClientUserAPI interface {
|
||||||
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
|
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||||
|
PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,14 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex
|
||||||
return a.DB.DeleteRegistrationToken(ctx, tokenString)
|
return a.DB.DeleteRegistrationToken(ctx, tokenString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
|
||||||
|
token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ type RegistrationTokens interface {
|
||||||
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
DeleteRegistrationToken(ctx context.Context, tokenString string) error
|
DeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||||
|
UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
|
|
|
||||||
|
|
@ -45,14 +45,26 @@ const getTokenSQL = "" +
|
||||||
const deleteTokenSQL = "" +
|
const deleteTokenSQL = "" +
|
||||||
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenUsesAllowedSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
|
||||||
|
|
||||||
|
const updateTokenExpiryTimeSQL = "" +
|
||||||
|
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
|
||||||
|
|
||||||
type registrationTokenStatements struct {
|
type registrationTokenStatements struct {
|
||||||
selectTokenStatement *sql.Stmt
|
selectTokenStatement *sql.Stmt
|
||||||
insertTokenStatement *sql.Stmt
|
insertTokenStatement *sql.Stmt
|
||||||
listAllTokensStatement *sql.Stmt
|
listAllTokensStatement *sql.Stmt
|
||||||
listValidTokensStatement *sql.Stmt
|
listValidTokensStatement *sql.Stmt
|
||||||
listInvalidTokenStatement *sql.Stmt
|
listInvalidTokenStatement *sql.Stmt
|
||||||
getTokenStatement *sql.Stmt
|
getTokenStatement *sql.Stmt
|
||||||
deleteTokenStatement *sql.Stmt
|
deleteTokenStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
|
||||||
|
updateTokenUsesAllowedStatement *sql.Stmt
|
||||||
|
updateTokenExpiryTimeStatement *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||||
|
|
@ -69,6 +81,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa
|
||||||
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||||
{&s.getTokenStatement, getTokenSQL},
|
{&s.getTokenStatement, getTokenSQL},
|
||||||
{&s.deleteTokenStatement, deleteTokenSQL},
|
{&s.deleteTokenStatement, deleteTokenSQL},
|
||||||
|
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||||||
|
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||||||
|
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -175,7 +190,7 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||||
stmt := s.getTokenStatement
|
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||||
var pending, completed, usesAllowed sql.NullInt32
|
var pending, completed, usesAllowed sql.NullInt32
|
||||||
var expiryTime sql.NullInt64
|
var expiryTime sql.NullInt64
|
||||||
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||||
|
|
@ -207,3 +222,29 @@ func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Contex
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||||||
|
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||||||
|
if usesAllowedPresent && expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if usesAllowedPresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if expiryTimePresent {
|
||||||
|
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||||||
|
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,14 @@ func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString stri
|
||||||
return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByPassword(
|
func (d *Database) GetAccountByPassword(
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ type RegistrationTokensTable interface {
|
||||||
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||||
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
|
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
|
||||||
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
|
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
|
||||||
|
UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountDataTable interface {
|
type AccountDataTable interface {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue