Add missing files

This commit is contained in:
Till Faelligen 2022-02-21 12:12:07 +01:00
parent 9c3a1cfd47
commit 2e6987f8bd
5 changed files with 23 additions and 27 deletions

View file

@ -57,17 +57,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx)
if err != nil {
return err
}
res.AccountCreated = true
res.Account = acc
return nil
}
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@ -612,7 +602,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
res *api.QueryPolicyVersionResponse,
) error {
var err error
res.PolicyVersion, err = a.AccountDB.GetPrivacyPolicy(ctx, req.LocalPart)
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart)
if err != nil {
return err
}
@ -626,7 +616,7 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
res *api.QueryOutdatedPolicyUsersResponse,
) error {
var err error
res.OutdatedUsers, err = a.AccountDB.GetOutdatedPolicy(ctx, req.PolicyVersion)
res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
if err != nil {
return err
}
@ -639,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse,
) error {
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
}

View file

@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
}
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion(
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) {
stmt := s.updatePolicyVersionStmt

View file

@ -16,7 +16,9 @@ package shared
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -671,8 +673,8 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
@ -680,8 +682,8 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
@ -689,8 +691,8 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart)
})
return
}

View file

@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
}
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion(
func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) {
stmt := s.updatePolicyVersionStmt

View file

@ -30,12 +30,16 @@ type AccountDataTable interface {
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error)
}
type DevicesTable interface {