Add posibility to track sent policy versions
This commit is contained in:
parent
cb4526793d
commit
fb95331aa2
|
@ -362,6 +362,7 @@ type QueryOutdatedPolicyUsersResponse struct {
|
||||||
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||||
type UpdatePolicyVersionRequest struct {
|
type UpdatePolicyVersionRequest struct {
|
||||||
PolicyVersion, LocalPart string
|
PolicyVersion, LocalPart string
|
||||||
|
ServerNoticeUpdate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||||
|
|
|
@ -629,5 +629,5 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||||
req *api.UpdatePolicyVersionRequest,
|
req *api.UpdatePolicyVersionRequest,
|
||||||
res *api.UpdatePolicyVersionResponse,
|
res *api.UpdatePolicyVersionResponse,
|
||||||
) error {
|
) error {
|
||||||
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ type Database interface {
|
||||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||||
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||||
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error
|
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
|
||||||
|
|
||||||
// Key backups
|
// Key backups
|
||||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||||
|
|
|
@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||||
account_type SMALLINT NOT NULL,
|
account_type SMALLINT NOT NULL,
|
||||||
-- The policy version this user has accepted
|
-- The policy version this user has accepted
|
||||||
policy_version TEXT
|
policy_version TEXT,
|
||||||
|
-- The policy version the user received from the server notices room
|
||||||
|
policy_version_sent TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
@ -75,11 +77,14 @@ const selectPrivacyPolicySQL = "" +
|
||||||
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
const batchSelectPrivacyPolicySQL = "" +
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)"
|
||||||
|
|
||||||
const updatePolicyVersionSQL = "" +
|
const updatePolicyVersionSQL = "" +
|
||||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
updatePasswordStmt *sql.Stmt
|
updatePasswordStmt *sql.Stmt
|
||||||
|
@ -90,6 +95,7 @@ type accountsStatements struct {
|
||||||
selectPrivacyPolicyStmt *sql.Stmt
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
updatePolicyVersionStmt *sql.Stmt
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,6 +117,7 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,9 +239,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
|
|
||||||
// updatePolicyVersion sets the policy_version for a specific user
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
func (s *accountsStatements) UpdatePolicyVersion(
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := s.updatePolicyVersionStmt
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if serverNotice {
|
||||||
|
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||||
|
}
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpAddPolicyVersion(tx *sql.Tx) error {
|
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;")
|
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;" +
|
||||||
|
"ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
|
||||||
|
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -690,9 +690,9 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart)
|
return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
-- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
|
||||||
account_type INTEGER NOT NULL,
|
account_type INTEGER NOT NULL,
|
||||||
-- The policy version this user has accepted
|
-- The policy version this user has accepted
|
||||||
policy_version TEXT
|
policy_version TEXT,
|
||||||
|
-- The policy version the user received from the server notices room
|
||||||
|
policy_version_sent TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
@ -73,11 +75,14 @@ const selectPrivacyPolicySQL = "" +
|
||||||
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
"SELECT policy_version FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
const batchSelectPrivacyPolicySQL = "" +
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
"SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)"
|
||||||
|
|
||||||
const updatePolicyVersionSQL = "" +
|
const updatePolicyVersionSQL = "" +
|
||||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -89,6 +94,7 @@ type accountsStatements struct {
|
||||||
selectPrivacyPolicyStmt *sql.Stmt
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
updatePolicyVersionStmt *sql.Stmt
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,6 +117,7 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,7 +225,10 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
rows, err := stmt.QueryContext(ctx, policyVersion)
|
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userID string
|
var userID string
|
||||||
|
@ -232,9 +242,12 @@ func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||||
|
|
||||||
// updatePolicyVersion sets the policy_version for a specific user
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
func (s *accountsStatements) UpdatePolicyVersion(
|
func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := s.updatePolicyVersionStmt
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if serverNotice {
|
||||||
|
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||||
|
}
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,8 @@ func LoadAddPolicyVersion(m *sqlutil.Migrations) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpAddPolicyVersion(tx *sql.Tx) error {
|
func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;")
|
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;" +
|
||||||
|
"ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -20,7 +21,8 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
|
||||||
|
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ type AccountsTable interface {
|
||||||
|
|
||||||
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||||
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||||
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string) (err error)
|
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
|
|
Loading…
Reference in a new issue