Add posibility to track sent policy versions

This commit is contained in:
Till Faelligen 2022-02-21 14:26:00 +01:00
parent cb4526793d
commit fb95331aa2
9 changed files with 65 additions and 37 deletions

View file

@ -362,6 +362,7 @@ type QueryOutdatedPolicyUsersResponse struct {
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest // UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
type UpdatePolicyVersionRequest struct { type UpdatePolicyVersionRequest struct {
PolicyVersion, LocalPart string PolicyVersion, LocalPart string
ServerNoticeUpdate bool
} }
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest // UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest

View file

@ -629,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
req *api.UpdatePolicyVersionRequest, req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse, res *api.UpdatePolicyVersionResponse,
) error { ) error {
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate)
} }

View file

@ -53,7 +53,7 @@ type Database interface {
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
// Key backups // Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)

View file

@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type SMALLINT NOT NULL, account_type SMALLINT NOT NULL,
-- The policy version this user has accepted -- The policy version this user has accepted
policy_version TEXT policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
@ -75,11 +77,14 @@ const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1" "SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
const updatePolicyVersionSQL = "" + const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
@ -90,6 +95,7 @@ type accountsStatements struct {
selectPrivacyPolicyStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -111,6 +117,7 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -232,9 +239,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
// updatePolicyVersion sets the policy_version for a specific user // updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion( func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) { ) (err error) {
stmt := s.updatePolicyVersionStmt stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }

View file

@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) {
} }
func UpAddPolicyVersion(tx *sql.Tx) error { func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;") _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;" +
"ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
} }
func DownAddPolicyVersion(tx *sql.Tx) error { func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -690,9 +690,9 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
} }
// UpdatePolicyVersion sets the accepted policy_version for a user. // UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart) return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
}) })
return return
} }

View file

@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
account_type INTEGER NOT NULL, account_type INTEGER NOT NULL,
-- The policy version this user has accepted -- The policy version this user has accepted
policy_version TEXT policy_version TEXT,
-- The policy version the user received from the server notices room
policy_version_sent TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
@ -73,11 +75,14 @@ const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM account_accounts WHERE localpart = $1" "SELECT policy_version FROM account_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
const updatePolicyVersionSQL = "" + const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -89,6 +94,7 @@ type accountsStatements struct {
selectPrivacyPolicyStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -111,6 +117,7 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -218,7 +225,10 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
rows, err := stmt.QueryContext(ctx, policyVersion) rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
if err != nil {
return nil, err
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var userID string var userID string
@ -232,9 +242,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
// updatePolicyVersion sets the policy_version for a specific user // updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) UpdatePolicyVersion( func (s *accountsStatements) UpdatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
) (err error) { ) (err error) {
stmt := s.updatePolicyVersionStmt stmt := s.updatePolicyVersionStmt
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }

View file

@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) {
} }
func UpAddPolicyVersion(tx *sql.Tx) error { func UpAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;") _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;" +
"ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
} }
func DownAddPolicyVersion(tx *sql.Tx) error { func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -39,7 +39,7 @@ type AccountsTable interface {
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error) UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
} }
type DevicesTable interface { type DevicesTable interface {