Add missing files

This commit is contained in:
Till Faelligen 2022-02-21 12:12:07 +01:00
parent 9c3a1cfd47
commit 2e6987f8bd
5 changed files with 23 additions and 27 deletions

View file

@ -57,17 +57,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
} }
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx)
if err != nil {
return err
}
res.AccountCreated = true
res.Account = acc
return nil
}
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -612,7 +602,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
res *api.QueryPolicyVersionResponse, res *api.QueryPolicyVersionResponse,
) error { ) error {
var err error var err error
res.PolicyVersion, err = a.AccountDB.GetPrivacyPolicy(ctx, req.LocalPart) res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart)
if err != nil { if err != nil {
return err return err
} }
@ -626,7 +616,7 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
res *api.QueryOutdatedPolicyUsersResponse, res *api.QueryOutdatedPolicyUsersResponse,
) error { ) error {
var err error var err error
res.OutdatedUsers, err = a.AccountDB.GetOutdatedPolicy(ctx, req.PolicyVersion) res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
if err != nil { if err != nil {
return err return err
} }
@ -639,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
req *api.UpdatePolicyVersionRequest, req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse, res *api.UpdatePolicyVersionResponse,
) error { ) error {
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
} }

View file

@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
} }
// selectPrivacyPolicy gets the current privacy policy a specific user accepted // selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy( func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string, ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) { ) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt stmt := s.selectPrivacyPolicyStmt
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
} }
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version // batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string, ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) { ) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt stmt := s.batchSelectPrivacyPolicyStmt
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
} }
// updatePolicyVersion sets the policy_version for a specific user // updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion( func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) { ) (err error) {
stmt := s.updatePolicyVersionStmt stmt := s.updatePolicyVersionStmt

View file

@ -16,7 +16,9 @@ package shared
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -671,8 +673,8 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
// GetPrivacyPolicy returns the accepted privacy policy version, if any. // GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart) policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
return err return err
}) })
return return
@ -680,8 +682,8 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
// GetOutdatedPolicy queries all users which didn't accept the current policy version // GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err return err
}) })
return return
@ -689,8 +691,8 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
// UpdatePolicyVersion sets the accepted policy_version for a user. // UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart) return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart)
}) })
return return
} }

View file

@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
// selectPrivacyPolicy gets the current privacy policy a specific user accepted // selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy( func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string, ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) { ) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt stmt := s.selectPrivacyPolicyStmt
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
} }
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version // batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string, ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) { ) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt stmt := s.batchSelectPrivacyPolicyStmt
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
} }
// updatePolicyVersion sets the policy_version for a specific user // updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion( func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) { ) (err error) {
stmt := s.updatePolicyVersionStmt stmt := s.updatePolicyVersionStmt

View file

@ -30,12 +30,16 @@ type AccountDataTable interface {
} }
type AccountsTable interface { type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error)
} }
type DevicesTable interface { type DevicesTable interface {