From 9583784e8a1110c5cb466f5a44736df9e6c113ce Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Mon, 14 Feb 2022 14:02:13 +0100 Subject: [PATCH] Add new coloumn to track accepted policy version --- userapi/storage/accounts/interface.go | 2 + .../accounts/postgres/accounts_table.go | 48 ++++++++++++++++++- userapi/storage/accounts/postgres/storage.go | 18 +++++++ .../accounts/sqlite3/accounts_table.go | 47 +++++++++++++++++- userapi/storage/accounts/sqlite3/storage.go | 18 +++++++ 5 files changed, 130 insertions(+), 3 deletions(-) diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index f03b3774c..e172c0951 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -52,6 +52,8 @@ type Database interface { DeactivateAccount(ctx context.Context, localpart string) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) + GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) + GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) // Key backups CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index b57aa901f..874b8abcf 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -39,9 +39,11 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT FALSE + is_deactivated BOOLEAN DEFAULT FALSE, + -- The policy version this user has accepted + policy_version TEXT -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; @@ -65,6 +67,12 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT nextval('numeric_username_seq')" +const selectPrivacyPolicySQL = "" + + "SELECT policy_version FROM accounts_accounts WHERE localpart = $1" + +const batchSelectPrivacyPolicySQL = "" + + "SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1" + type accountsStatements struct { insertAccountStmt *sql.Stmt updatePasswordStmt *sql.Stmt @@ -72,6 +80,8 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, + {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, }.Prepare(db) } @@ -174,3 +186,35 @@ func (s *accountsStatements) selectNewNumericLocalpart( err = stmt.QueryRowContext(ctx).Scan(&id) return } + +// selectPrivacyPolicy gets the current privacy policy a specific user accepted +func (s *accountsStatements) selectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, localPart string, +) (policy string, err error) { + stmt := s.selectPrivacyPolicyStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) + return +} + +// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version +func (s *accountsStatements) batchSelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, policyVersion string, +) (userIDs []string, err error) { + stmt := s.batchSelectPrivacyPolicyStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + rows, err := stmt.QueryContext(ctx, policyVersion) + defer rows.Close() + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return userIDs, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2f8290623..b128f450c 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -518,3 +518,21 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetPrivacyPolicy returns the accepted privacy policy version, if any. +func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart) + return err + }) + return +} + +// GetOutdatedPolicy queries all users which didn't accept the current policy version +func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) + return err + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 8a7c8fba7..eab6114f1 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -39,7 +39,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT 0 + is_deactivated BOOLEAN DEFAULT 0, + -- The policy version this user has accepted + policy_version TEXT -- TODO: -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? ); @@ -63,6 +65,12 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COUNT(localpart) FROM account_accounts" +const selectPrivacyPolicySQL = "" + + "SELECT policy_version FROM accounts_accounts WHERE localpart = $1" + +const batchSelectPrivacyPolicySQL = "" + + "SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1" + type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt @@ -71,6 +79,8 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, + {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, }.Prepare(db) } @@ -174,3 +186,36 @@ func (s *accountsStatements) selectNewNumericLocalpart( err = stmt.QueryRowContext(ctx).Scan(&id) return } + +// selectPrivacyPolicy gets the current privacy policy a specific user accepted + +func (s *accountsStatements) selectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, localPart string, +) (policy string, err error) { + stmt := s.selectPrivacyPolicyStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) + return +} + +// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version +func (s *accountsStatements) batchSelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, policyVersion string, +) (userIDs []string, err error) { + stmt := s.batchSelectPrivacyPolicyStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + rows, err := stmt.QueryContext(ctx, policyVersion) + defer rows.Close() + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return userIDs, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 2b731b759..911492d53 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -561,3 +561,21 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetPrivacyPolicy returns the accepted privacy policy version, if any. +func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart) + return err + }) + return +} + +// GetOutdatedPolicy queries all users which didn't accept the current policy version +func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) + return err + }) + return +}