Add missing files
This commit is contained in:
parent
9c3a1cfd47
commit
2e6987f8bd
|
@ -57,17 +57,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
if req.AccountType == api.AccountTypeGuest {
|
||||
acc, err := a.AccountDB.CreateGuestAccount(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.AccountCreated = true
|
||||
res.Account = acc
|
||||
return nil
|
||||
}
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
@ -612,7 +602,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
|
|||
res *api.QueryPolicyVersionResponse,
|
||||
) error {
|
||||
var err error
|
||||
res.PolicyVersion, err = a.AccountDB.GetPrivacyPolicy(ctx, req.LocalPart)
|
||||
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -626,7 +616,7 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
|
|||
res *api.QueryOutdatedPolicyUsersResponse,
|
||||
) error {
|
||||
var err error
|
||||
res.OutdatedUsers, err = a.AccountDB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||
res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -639,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
|||
req *api.UpdatePolicyVersionRequest,
|
||||
res *api.UpdatePolicyVersionResponse,
|
||||
) error {
|
||||
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
||||
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
||||
}
|
||||
|
|
|
@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
}
|
||||
|
||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||
func (s *accountsStatements) selectPrivacyPolicy(
|
||||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
stmt := s.selectPrivacyPolicyStmt
|
||||
|
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
|
|||
}
|
||||
|
||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := s.batchSelectPrivacyPolicyStmt
|
||||
|
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
|||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) updatePolicyVersion(
|
||||
func (s *accountsStatements) UpdatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
|
|
|
@ -16,7 +16,9 @@ package shared
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -671,8 +673,8 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
|||
|
||||
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -680,8 +682,8 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
|
|||
|
||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -689,8 +691,8 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
|||
|
||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
|
||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||
|
||||
func (s *accountsStatements) selectPrivacyPolicy(
|
||||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
stmt := s.selectPrivacyPolicyStmt
|
||||
|
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
|
|||
}
|
||||
|
||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := s.batchSelectPrivacyPolicyStmt
|
||||
|
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
|||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) updatePolicyVersion(
|
||||
func (s *accountsStatements) UpdatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
|
|
|
@ -30,12 +30,16 @@ type AccountDataTable interface {
|
|||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
|
||||
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
|
|
Loading…
Reference in a new issue