Add new coloumn to track accepted policy version

This commit is contained in:
Till Faelligen 2022-02-14 14:02:13 +01:00
parent b6ee34918c
commit 9583784e8a
5 changed files with 130 additions and 3 deletions

View file

@ -52,6 +52,8 @@ type Database interface {
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
// Key backups // Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)

View file

@ -39,9 +39,11 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT FALSE is_deactivated BOOLEAN DEFAULT FALSE,
-- The policy version this user has accepted
policy_version TEXT
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
); );
-- Create sequence for autogenerated numeric usernames -- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
@ -65,6 +67,12 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')" "SELECT nextval('numeric_username_seq')"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
@ -72,6 +80,8 @@ type accountsStatements struct {
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
}.Prepare(db) }.Prepare(db)
} }
@ -174,3 +186,35 @@ func (s *accountsStatements) selectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion)
defer rows.Close()
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}

View file

@ -518,3 +518,21 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}

View file

@ -39,7 +39,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any. -- Identifies which application service this account belongs to, if any.
appservice_id TEXT, appservice_id TEXT,
-- If the account is currently active -- If the account is currently active
is_deactivated BOOLEAN DEFAULT 0 is_deactivated BOOLEAN DEFAULT 0,
-- The policy version this user has accepted
policy_version TEXT
-- TODO: -- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
); );
@ -63,6 +65,12 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts" "SELECT COUNT(localpart) FROM account_accounts"
const selectPrivacyPolicySQL = "" +
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -71,6 +79,8 @@ type accountsStatements struct {
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
{&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL},
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
}.Prepare(db) }.Prepare(db)
} }
@ -174,3 +186,36 @@ func (s *accountsStatements) selectNewNumericLocalpart(
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return
} }
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
func (s *accountsStatements) selectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
func (s *accountsStatements) batchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion)
defer rows.Close()
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return userIDs, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}

View file

@ -561,3 +561,21 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
return err
})
return
}
// GetOutdatedPolicy queries all users which didn't accept the current policy version
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
return err
})
return
}