Add missing files
This commit is contained in:
parent
9c3a1cfd47
commit
2e6987f8bd
|
@ -57,17 +57,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType)
|
||||||
if req.AccountType == api.AccountTypeGuest {
|
|
||||||
acc, err := a.AccountDB.CreateGuestAccount(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
res.AccountCreated = true
|
|
||||||
res.Account = acc
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
@ -612,7 +602,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
|
||||||
res *api.QueryPolicyVersionResponse,
|
res *api.QueryPolicyVersionResponse,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
res.PolicyVersion, err = a.AccountDB.GetPrivacyPolicy(ctx, req.LocalPart)
|
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -626,7 +616,7 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
|
||||||
res *api.QueryOutdatedPolicyUsersResponse,
|
res *api.QueryOutdatedPolicyUsersResponse,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
res.OutdatedUsers, err = a.AccountDB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -639,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||||
req *api.UpdatePolicyVersionRequest,
|
req *api.UpdatePolicyVersionRequest,
|
||||||
res *api.UpdatePolicyVersionResponse,
|
res *api.UpdatePolicyVersionResponse,
|
||||||
) error {
|
) error {
|
||||||
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
func (s *accountsStatements) selectPrivacyPolicy(
|
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||||
ctx context.Context, txn *sql.Tx, localPart string,
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
) (policy string, err error) {
|
) (policy string, err error) {
|
||||||
stmt := s.selectPrivacyPolicyStmt
|
stmt := s.selectPrivacyPolicyStmt
|
||||||
|
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
|
|
||||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
) (userIDs []string, err error) {
|
) (userIDs []string, err error) {
|
||||||
stmt := s.batchSelectPrivacyPolicyStmt
|
stmt := s.batchSelectPrivacyPolicyStmt
|
||||||
|
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
|
|
||||||
// updatePolicyVersion sets the policy_version for a specific user
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
func (s *accountsStatements) updatePolicyVersion(
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := s.updatePolicyVersionStmt
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
|
|
@ -16,7 +16,9 @@ package shared
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -671,8 +673,8 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
||||||
|
|
||||||
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||||
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
policyVersion, err = d.Accounts.selectPrivacyPolicy(ctx, txn, localpart)
|
policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -680,8 +682,8 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
|
||||||
|
|
||||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||||
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -689,8 +691,8 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
||||||
|
|
||||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,7 +199,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
|
|
||||||
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
|
|
||||||
func (s *accountsStatements) selectPrivacyPolicy(
|
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||||
ctx context.Context, txn *sql.Tx, localPart string,
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
) (policy string, err error) {
|
) (policy string, err error) {
|
||||||
stmt := s.selectPrivacyPolicyStmt
|
stmt := s.selectPrivacyPolicyStmt
|
||||||
|
@ -211,7 +211,7 @@ func (s *accountsStatements) selectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
|
|
||||||
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
) (userIDs []string, err error) {
|
) (userIDs []string, err error) {
|
||||||
stmt := s.batchSelectPrivacyPolicyStmt
|
stmt := s.batchSelectPrivacyPolicyStmt
|
||||||
|
@ -231,7 +231,7 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
|
|
||||||
// updatePolicyVersion sets the policy_version for a specific user
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
func (s *accountsStatements) updatePolicyVersion(
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := s.updatePolicyVersionStmt
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
|
|
@ -30,12 +30,16 @@ type AccountDataTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsTable interface {
|
type AccountsTable interface {
|
||||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error)
|
||||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
|
|
||||||
|
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||||
|
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||||
|
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
|
|
Loading…
Reference in a new issue