diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 96d557ae4..1644a3009 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -198,6 +198,46 @@ func getReturnValueExpiryTime(expiryTime int64) interface{} { return expiryTime } +func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: token, + } +} + +func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2d96e05cc..e1620c53c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -169,11 +169,29 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens", - httputil.MakeAdminAPI("admin_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminListRegistrationTokens(req, cfg, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", + httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if req.Method == http.MethodGet { + return AdminGetRegistrationToken(req, cfg, userAPI) + } else if req.Method == http.MethodPut { + + } else if req.Method == http.MethodDelete { + return AdminDeleteRegistrationToken(req, cfg, userAPI) + } + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) + + }), + ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) diff --git a/userapi/api/api.go b/userapi/api/api.go index 9f014d3f0..532422d84 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -97,6 +97,8 @@ type ClientUserAPI interface { QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 65ea6a868..07f425097 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -87,6 +87,18 @@ func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context return tokens, nil } +func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + token, err := a.DB.GetRegistrationToken(ctx, tokenString) + if err != nil { + return nil, err + } + return token, nil +} + +func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { + return a.DB.DeleteRegistrationToken(ctx, tokenString) +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 986da99b5..144cd1f73 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -35,6 +35,8 @@ type RegistrationTokens interface { RegistrationTokenExists(ctx context.Context, token string) (bool, error) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, tokenString string) error } type Profile interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 3bb8b19e8..0163a2924 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" @@ -38,12 +39,20 @@ const listInvalidTokensSQL = "" + "SELECT * FROM userapi_registration_tokens WHERE" + "(uses_allowed <= pending + completed OR expiry_time <= $1)" +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + type registrationTokenStatements struct { selectTokenStatement *sql.Stmt insertTokenStatement *sql.Stmt listAllTokensStatement *sql.Stmt listValidTokensStatement *sql.Stmt listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -58,6 +67,8 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa {&s.listAllTokensStatement, listAllTokensSQL}, {&s.listValidTokensStatement, listValidTokensSQL}, {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, }.Prepare(db) } @@ -129,12 +140,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context if err != nil { return tokens, err } + tokenString := tokenString.String + pending := pending.Int32 + completed := completed.Int32 + usesAllowed := getReturnValueForInt32(usesAllowed) + expiryTime := getReturnValueForInt64(expiryTime) + tokenMap := api.RegistrationToken{ - Token: &tokenString.String, - Pending: &pending.Int32, - Completed: &pending.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Token: &tokenString, + Pending: &pending, + Completed: &completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } @@ -156,3 +173,37 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 { } return nil } + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := s.getTokenStatement + var pending, completed, usesAllowed sql.NullInt32 + var expiryTime sql.NullInt64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: &pending.Int32, + Completed: &completed.Int32, + UsesAllowed: getReturnValueForInt32(usesAllowed), + ExpiryTime: getReturnValueForInt64(expiryTime), + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := s.deleteTokenStatement + res, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + count, err := res.RowsAffected() + if err != nil { + return err + } + if count == 0 { + return fmt.Errorf("token: %s does not exists", tokenString) + } + return nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 58ebe30ec..86f1fd9e6 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,6 +96,14 @@ func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, v return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) } +func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) +} + +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error { + return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index dfd52235a..fe902481a 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -34,6 +34,8 @@ type RegistrationTokensTable interface { RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error } type AccountDataTable interface {