Add new coloumn to track accepted policy version

This commit is contained in:
Till Faelligen 2022-02-14 14:02:13 +01:00
parent b6ee34918c
commit 9583784e8a
5 changed files with 130 additions and 3 deletions

View file

@ -52,6 +52,8 @@ type Database interface {
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
// Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)

View file

@ -39,9 +39,11 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT,
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE
is_deactivated BOOLEAN DEFAULT FALSE,
-- The policy version this user has accepted
policy_version TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
@ -65,6 +67,12 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
@ -72,6 +80,8 @@ type accountsStatements struct {
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
}.Prepare(db)
}
@ -174,3 +186,35 @@ func (s *accountsStatements) selectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id)
return
}
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion)
defer rows.Close()
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}

View file

@ -518,3 +518,21 @@ func (d *Database) UpsertBackupKeys(
})
return
}
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}

View file

@ -39,7 +39,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT,
-- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0
is_deactivated BOOLEAN DEFAULT 0,
-- The policy version this user has accepted
policy_version TEXT
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
@ -63,6 +65,12 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
@ -71,6 +79,8 @@ type accountsStatements struct {
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
}.Prepare(db)
}
@ -174,3 +186,36 @@ func (s *accountsStatements) selectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id)
return
}
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion)
defer rows.Close()
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}

View file

@ -561,3 +561,21 @@ func (d *Database) UpsertBackupKeys(
})
return
}
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}