From c2aa0209210f6fa3842440b2fbd85aed2fde4473 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Thu, 10 Feb 2022 13:29:36 +0100 Subject: [PATCH] Add account_type for postgres --- .../accounts/postgres/accounts_table.go | 19 +++++---- userapi/storage/accounts/postgres/storage.go | 41 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index b57aa901f..30a77bd13 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -19,10 +19,11 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT FALSE + is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type INT DEFAULT 2 -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error - if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2f8290623..24b60ac14 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -23,13 +23,14 @@ import ( "strconv" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" // Import the postgres database driver. _ "github.com/lib/pq" @@ -73,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } @@ -155,37 +157,32 @@ func (d *Database) SetPassword( return d.accounts.updatePassword(ctx, localpart, hash) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, sqlutil.ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) return err }) return } func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var account *api.Account var err error @@ -197,7 +194,7 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if sqlutil.IsUniqueConstraintViolationErr(err) { return nil, sqlutil.ErrUserExists }