diff --git a/userapi/internal/api.go b/userapi/internal/api.go index fdcf796fd..57d8e09e0 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -57,17 +57,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) - if req.AccountType == api.AccountTypeGuest { - acc, err := a.AccountDB.CreateGuestAccount(ctx) - if err != nil { - return err - } - res.AccountCreated = true - res.Account = acc - return nil - } - acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion) + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -612,7 +602,7 @@ func (a *UserInternalAPI) QueryPolicyVersion( res *api.QueryPolicyVersionResponse, ) error { var err error - res.PolicyVersion, err = a.AccountDB.GetPrivacyPolicy(ctx, req.LocalPart) + res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart) if err != nil { return err } @@ -626,7 +616,7 @@ func (a *UserInternalAPI) GetOutdatedPolicy( res *api.QueryOutdatedPolicyUsersResponse, ) error { var err error - res.OutdatedUsers, err = a.AccountDB.GetOutdatedPolicy(ctx, req.PolicyVersion) + res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion) if err != nil { return err } @@ -639,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion( req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse, ) error { - return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) + return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index f276e2038..2eedee82f 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart( } // selectPrivacyPolicy gets the current privacy policy a specific user accepted -func (s *accountsStatements) selectPrivacyPolicy( +func (s *accountsStatements) SelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, localPart string, ) (policy string, err error) { stmt := s.selectPrivacyPolicyStmt @@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy( } // batchSelectPrivacyPolicy queries all users which didn't accept the current policy version -func (s *accountsStatements) batchSelectPrivacyPolicy( +func (s *accountsStatements) BatchSelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, policyVersion string, ) (userIDs []string, err error) { stmt := s.batchSelectPrivacyPolicyStmt @@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy( } // updatePolicyVersion sets the policy_version for a specific user -func (s *accountsStatements) updatePolicyVersion( +func (s *accountsStatements) UpdatePolicyVersion( ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ) (err error) { stmt := s.updatePolicyVersionStmt diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 1d48315cf..c28505f15 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -16,7 +16,9 @@ package shared import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -671,8 +673,8 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( // GetPrivacyPolicy returns the accepted privacy policy version, if any. func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart) return err }) return @@ -680,8 +682,8 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli // GetOutdatedPolicy queries all users which didn't accept the current policy version func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion) return err }) return @@ -689,8 +691,8 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) // UpdatePolicyVersion sets the accepted policy_version for a user. func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart) }) return } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 483736a80..0bf54ee18 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart( // selectPrivacyPolicy gets the current privacy policy a specific user accepted -func (s *accountsStatements) selectPrivacyPolicy( +func (s *accountsStatements) SelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, localPart string, ) (policy string, err error) { stmt := s.selectPrivacyPolicyStmt @@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy( } // batchSelectPrivacyPolicy queries all users which didn't accept the current policy version -func (s *accountsStatements) batchSelectPrivacyPolicy( +func (s *accountsStatements) BatchSelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, policyVersion string, ) (userIDs []string, err error) { stmt := s.batchSelectPrivacyPolicyStmt @@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy( } // updatePolicyVersion sets the policy_version for a specific user -func (s *accountsStatements) updatePolicyVersion( +func (s *accountsStatements) UpdatePolicyVersion( ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ) (err error) { stmt := s.updatePolicyVersionStmt diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 12939ced5..59f96d0b5 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -30,12 +30,16 @@ type AccountDataTable interface { } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error) UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error) SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + + SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) + BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) + UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error) } type DevicesTable interface {