From fe2464fd4b3b9384cfc62937d3b7302cb56d3dfc Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 21:52:05 +0530 Subject: [PATCH] Move DB Layer to UserAPI --- clientapi/routing/admin.go | 13 +++- clientapi/routing/routing.go | 2 +- roomserver/api/api.go | 1 - roomserver/internal/perform/perform_admin.go | 15 ----- roomserver/storage/interface.go | 1 - .../postgres/registration_tokens_table.go | 25 ------- roomserver/storage/postgres/storage.go | 3 - roomserver/storage/shared/storage.go | 27 ++++---- roomserver/storage/tables/interface.go | 4 -- userapi/api/api.go | 1 + userapi/internal/user_api.go | 15 +++++ userapi/storage/interface.go | 6 ++ .../postgres/registration_tokens_table.go | 66 +++++++++++++++++++ userapi/storage/postgres/storage.go | 5 ++ userapi/storage/shared/storage.go | 13 ++++ userapi/storage/tables/interface.go | 5 ++ 16 files changed, 133 insertions(+), 69 deletions(-) delete mode 100644 roomserver/storage/postgres/registration_tokens_table.go create mode 100644 userapi/storage/postgres/registration_tokens_table.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 558e011e3..a0a307273 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) func generateRandomToken(length int) string { @@ -38,7 +39,7 @@ func generateRandomToken(length int) string { return sb.String() } -func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( http.StatusForbidden, @@ -67,7 +68,11 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r length := request.Length if len(token) == 0 { - // Token not present in request body. Hence, generate a random token. + if length == 0 { + // length not provided in request. Assign default value of 16. + length = 16 + } + // token not present in request body. Hence, generate a random token. if !(length > 0 && length <= 64) { return util.MatrixErrorResponse( http.StatusBadRequest, @@ -109,7 +114,9 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r } pending := 0 completed := 0 - created, err := rsAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, int32(pending), int32(completed), expiryTime) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating + // unlimited uses / no expiration will be persisted in DB) + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime) if err != nil { return util.MatrixErrorResponse( http.StatusInternalServerError, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index efa3f45e8..bbca60227 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -164,7 +164,7 @@ func Setup( } dendriteAdminRouter.Handle("/admin/registrationTokens/new", httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminCreateNewRegistrationToken(req, cfg, rsAPI) + return AdminCreateNewRegistrationToken(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 54762b6ff..7cb3379e0 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -173,7 +173,6 @@ type ClientRoomserverAPI interface { PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) - PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminPurgeRoom(ctx context.Context, roomID string) error diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 292d91f23..575525e21 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -42,21 +42,6 @@ type Admin struct { Leaver *Leaver } -func (r *Admin) PerformAdminCreateRegistrationToken( - ctx context.Context, token string, - usesAllowed, pending, completed int32, - expiryTime int64) (bool, error) { - exists, err := r.DB.RegistrationTokenExists(ctx, token) - if err != nil { - return false, err - } - if exists { - fmt.Println(fmt.Sprintf("token: %s already exists", token)) - return false, fmt.Errorf("token: %s already exists", token) - } - return true, nil -} - // PerformAdminEvacuateRoom will remove all local users from the given room. func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 4cf8f3b3a..7d22df008 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -27,7 +27,6 @@ import ( ) type Database interface { - RegistrationTokenExists(ctx context.Context, token string) (bool, error) // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. diff --git a/roomserver/storage/postgres/registration_tokens_table.go b/roomserver/storage/postgres/registration_tokens_table.go deleted file mode 100644 index 8fd0e41f9..000000000 --- a/roomserver/storage/postgres/registration_tokens_table.go +++ /dev/null @@ -1,25 +0,0 @@ -package postgres - -import ( - "context" - "database/sql" -) - -const registrationTokensSchema = ` -CREATE TABLE IF NOT EXISTS roomserver_registration_tokens ( - token TEXT PRIMARY KEY, - pending BIGINT, - completed BIGINT, - uses_allowed BIGINT, - expiry_time BIGINT -); -` - -func CreateRegistrationTokensTable(db *sql.DB) error { - _, err := db.Exec(registrationTokensSchema) - return err -} - -func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { - return true, nil -} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 5836ab153..19cde5410 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -92,9 +92,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { } func (d *Database) create(db *sql.DB) error { - if err := CreateRegistrationTokensTable(db); err != nil { - return err - } if err := CreateEventStateKeysTable(db); err != nil { return err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 0a8a358ef..cefa58a3d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -31,18 +31,17 @@ const redactionsArePermanent = true type Database struct { DB *sql.DB EventDatabase - Cache caching.RoomServerCaches - Writer sqlutil.Writer - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - Purge tables.Purge - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) - RegistrationTokensTable tables.RegistrationTokens + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } // EventDatabase contains all tables needed to work with events @@ -58,10 +57,6 @@ type EventDatabase struct { RedactionsTable tables.Redactions } -func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { - return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token) -} - func (d *Database) SupportsConcurrentRoomInputs() bool { return true } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 471b341eb..333483b32 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -19,10 +19,6 @@ type EventJSONPair struct { EventJSON []byte } -type RegistrationTokens interface { - RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) -} - type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error diff --git a/userapi/api/api.go b/userapi/api/api.go index 050402645..1dfae8ed1 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -94,6 +94,7 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 32f3d84b5..8f388ab82 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -63,6 +63,21 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, token) + if err != nil { + return false, err + } + if exists { + return false, fmt.Errorf("token: %s already exists", token) + } + _, err = a.DB.InsertRegistrationToken(ctx, token, usesAllowed, expiryTime) + if err != nil { + return false, fmt.Errorf("Error creating token: %s"+err.Error(), token) + } + return true, nil +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4f5e99a8a..8815df68f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -30,6 +30,11 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) +} + type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) @@ -144,6 +149,7 @@ type UserDatabase interface { Pusher Statistics ThreePID + RegistrationTokens } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go new file mode 100644 index 000000000..750e53b26 --- /dev/null +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatment *sql.Stmt +} + +func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatment, insertTokenSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := s.selectTokenStatement + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatment) + pending := 0 + completed := 0 + _, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed) + if err != nil { + return false, err + } + return true, nil +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 72e7c9cd9..d01ccc776 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * return nil, err } + registationTokensTable, err := NewPostgresRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err) + } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) @@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + RegistrationTokens: registationTokensTable, Stats: statsTable, ServerName: serverName, DB: db, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 537bbbf4a..9ec210391 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -43,6 +43,7 @@ import ( type Database struct { DB *sql.DB Writer sqlutil.Writer + RegistrationTokens tables.RegistrationTokensTable Accounts tables.AccountsTable Profiles tables.ProfileTable AccountDatas tables.AccountDataTable @@ -78,6 +79,18 @@ const ( loginTokenByteLength = 32 ) +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) +} + +func (d *Database) InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (created bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, token, usesAllowed, expiryTime) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3c6214e7c..41c99baed 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -29,6 +29,11 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokensTable interface { + RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) +} + type AccountDataTable interface { InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)