From fb95331aa24b6e9e01c15288ab8a867e0659fde7 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Mon, 21 Feb 2022 14:26:00 +0100 Subject: [PATCH] Add posibility to track sent policy versions --- userapi/api/api.go | 1 + userapi/internal/api.go | 2 +- userapi/storage/interface.go | 2 +- userapi/storage/postgres/accounts_table.go | 36 ++++++++++------ .../2022021414375800_add_policy_version.go | 6 ++- userapi/storage/shared/storage.go | 4 +- userapi/storage/sqlite3/accounts_table.go | 43 ++++++++++++------- .../2022021414375800_add_policy_version.go | 6 ++- userapi/storage/tables/interface.go | 2 +- 9 files changed, 65 insertions(+), 37 deletions(-) diff --git a/userapi/api/api.go b/userapi/api/api.go index 5140cc5b8..a40275307 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -362,6 +362,7 @@ type QueryOutdatedPolicyUsersResponse struct { // UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest type UpdatePolicyVersionRequest struct { PolicyVersion, LocalPart string + ServerNoticeUpdate bool } // UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 57d8e09e0..862dd0c2c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -629,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion( req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse, ) error { - return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) + return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate) } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 9e4ff4b4b..39164fed0 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -53,7 +53,7 @@ type Database interface { GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) - UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error + UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error // Key backups CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 2eedee82f..22ef83274 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) account_type SMALLINT NOT NULL, -- The policy version this user has accepted - policy_version TEXT + policy_version TEXT, + -- The policy version the user received from the server notices room + policy_version_sent TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); @@ -75,22 +77,26 @@ const selectPrivacyPolicySQL = "" + "SELECT policy_version FROM account_accounts WHERE localpart = $1" const batchSelectPrivacyPolicySQL = "" + - "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" + "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)" const updatePolicyVersionSQL = "" + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" +const updatePolicyVersionServerNoticeSQL = "" + + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" + type accountsStatements struct { - insertAccountStmt *sql.Stmt - updatePasswordStmt *sql.Stmt - deactivateAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - selectNewNumericLocalpartStmt *sql.Stmt - selectPrivacyPolicyStmt *sql.Stmt - batchSelectPrivacyPolicyStmt *sql.Stmt - updatePolicyVersionStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt + updatePolicyVersionServerNoticeStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { @@ -111,6 +117,7 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, + {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, }.Prepare(db) } @@ -232,9 +239,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy( // updatePolicyVersion sets the policy_version for a specific user func (s *accountsStatements) UpdatePolicyVersion( - ctx context.Context, txn *sql.Tx, policyVersion, localpart string, + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool, ) (err error) { stmt := s.updatePolicyVersionStmt + if serverNotice { + stmt = s.updatePolicyVersionServerNoticeStmt + } if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } diff --git a/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go b/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go index 8cbbe6f33..10b1006df 100644 --- a/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go +++ b/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go @@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) { } func UpAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;") + _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;" + + "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error { } func DownAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + + "ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index c28505f15..085a4e207 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -690,9 +690,9 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) } // UpdatePolicyVersion sets the accepted policy_version for a user. -func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { +func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart) + return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice) }) return } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 0bf54ee18..cbd073d51 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) account_type INTEGER NOT NULL, -- The policy version this user has accepted - policy_version TEXT + policy_version TEXT, + -- The policy version the user received from the server notices room + policy_version_sent TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); @@ -73,23 +75,27 @@ const selectPrivacyPolicySQL = "" + "SELECT policy_version FROM account_accounts WHERE localpart = $1" const batchSelectPrivacyPolicySQL = "" + - "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" + "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)" const updatePolicyVersionSQL = "" + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" +const updatePolicyVersionServerNoticeSQL = "" + + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" + type accountsStatements struct { - db *sql.DB - insertAccountStmt *sql.Stmt - updatePasswordStmt *sql.Stmt - deactivateAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - selectNewNumericLocalpartStmt *sql.Stmt - selectPrivacyPolicyStmt *sql.Stmt - batchSelectPrivacyPolicyStmt *sql.Stmt - updatePolicyVersionStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + db *sql.DB + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt + updatePolicyVersionServerNoticeStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { @@ -111,6 +117,7 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, + {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, }.Prepare(db) } @@ -218,7 +225,10 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy( if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - rows, err := stmt.QueryContext(ctx, policyVersion) + rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion) + if err != nil { + return nil, err + } defer rows.Close() for rows.Next() { var userID string @@ -232,9 +242,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy( // updatePolicyVersion sets the policy_version for a specific user func (s *accountsStatements) UpdatePolicyVersion( - ctx context.Context, txn *sql.Tx, policyVersion, localpart string, + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool, ) (err error) { stmt := s.updatePolicyVersionStmt + if serverNotice { + stmt = s.updatePolicyVersionServerNoticeStmt + } if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } diff --git a/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go b/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go index ae69cf0f0..2292b9031 100644 --- a/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go +++ b/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go @@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) { } func UpAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;") + _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;" + + "ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error { } func DownAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + + "ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 59f96d0b5..a126086b1 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -39,7 +39,7 @@ type AccountsTable interface { SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) - UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error) + UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error) } type DevicesTable interface {