Add new coloumn to track accepted policy version
This commit is contained in:
parent
b6ee34918c
commit
9583784e8a
|
@ -52,6 +52,8 @@ type Database interface {
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
|
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||||
|
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||||
|
|
||||||
// Key backups
|
// Key backups
|
||||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||||
|
|
|
@ -39,9 +39,11 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- Identifies which application service this account belongs to, if any.
|
-- Identifies which application service this account belongs to, if any.
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
is_deactivated BOOLEAN DEFAULT FALSE
|
is_deactivated BOOLEAN DEFAULT FALSE,
|
||||||
|
-- The policy version this user has accepted
|
||||||
|
policy_version TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
-- Create sequence for autogenerated numeric usernames
|
-- Create sequence for autogenerated numeric usernames
|
||||||
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
|
@ -65,6 +67,12 @@ const selectPasswordHashSQL = "" +
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT nextval('numeric_username_seq')"
|
"SELECT nextval('numeric_username_seq')"
|
||||||
|
|
||||||
|
const selectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
updatePasswordStmt *sql.Stmt
|
updatePasswordStmt *sql.Stmt
|
||||||
|
@ -72,6 +80,8 @@ type accountsStatements struct {
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,3 +186,35 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
|
func (s *accountsStatements) selectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
|
) (policy string, err error) {
|
||||||
|
stmt := s.selectPrivacyPolicyStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
|
) (userIDs []string, err error) {
|
||||||
|
stmt := s.batchSelectPrivacyPolicyStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, policyVersion)
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return userIDs, err
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return userIDs, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -518,3 +518,21 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||||
|
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -39,7 +39,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- Identifies which application service this account belongs to, if any.
|
-- Identifies which application service this account belongs to, if any.
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
-- If the account is currently active
|
-- If the account is currently active
|
||||||
is_deactivated BOOLEAN DEFAULT 0
|
is_deactivated BOOLEAN DEFAULT 0,
|
||||||
|
-- The policy version this user has accepted
|
||||||
|
policy_version TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
@ -63,6 +65,12 @@ const selectPasswordHashSQL = "" +
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COUNT(localpart) FROM account_accounts"
|
"SELECT COUNT(localpart) FROM account_accounts"
|
||||||
|
|
||||||
|
const selectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT policy_version FROM accounts_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
|
"SELECT localpart FROM accounts_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -71,6 +79,8 @@ type accountsStatements struct {
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +99,8 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
{&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL},
|
||||||
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
{&s.selectPasswordHashStmt, selectPasswordHashSQL},
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,3 +186,36 @@ func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectPrivacyPolicy gets the current privacy policy a specific user accepted
|
||||||
|
|
||||||
|
func (s *accountsStatements) selectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, localPart string,
|
||||||
|
) (policy string, err error) {
|
||||||
|
stmt := s.selectPrivacyPolicyStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||||
|
) (userIDs []string, err error) {
|
||||||
|
stmt := s.batchSelectPrivacyPolicyStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, policyVersion)
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return userIDs, err
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return userIDs, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -561,3 +561,21 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrivacyPolicy returns the accepted privacy policy version, if any.
|
||||||
|
func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
policyVersion, err = d.accounts.selectPrivacyPolicy(ctx, txn, localpart)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||||
|
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue