diff --git a/userapi/api/api.go b/userapi/api/api.go index e5a173cec..f38180a63 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -46,6 +46,7 @@ type UserInternalAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error + PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error } type PerformKeyBackupRequest struct { @@ -347,7 +348,7 @@ type QueryPolicyVersionResponse struct { PolicyVersion string } -// QueryOutdatedPolicyUsersRequest is the response for QueryOutdatedPolicyUsersRequest +// QueryOutdatedPolicyUsersRequest is the request for QueryOutdatedPolicyUsersRequest type QueryOutdatedPolicyUsersRequest struct { PolicyVersion string } @@ -357,6 +358,14 @@ type QueryOutdatedPolicyUsersResponse struct { OutdatedUsers []string } +// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest +type UpdatePolicyVersionRequest struct { + PolicyVersion, LocalPart string +} + +// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest +type UpdatePolicyVersionResponse struct{} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 15b5572f4..3428e0fd3 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -127,7 +127,13 @@ func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *Quer func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error { err := t.Impl.GetOutdatedPolicy(ctx, req, res) - util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res)) + util.GetLogger(ctx).Infof("GetOutdatedPolicy req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error { + err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res)) return err } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 0d7a723a9..a6b933e3f 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -615,3 +615,11 @@ func (a *UserInternalAPI) GetOutdatedPolicy( return nil } + +func (a *UserInternalAPI) PerformUpdatePolicyVersion( + ctx context.Context, + req *api.UpdatePolicyVersionRequest, + res *api.UpdatePolicyVersionResponse, +) error { + return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart) +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4a99d1b13..9067d993e 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,17 +37,18 @@ const ( PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" - QueryKeyBackupPath = "/userapi/queryKeyBackup" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" - QuerySearchProfilesPath = "/userapi/querySearchProfiles" - QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" - QueryPolicyVersion = "/userapi/queryPolicyVersion" - QueryOutdatedPolicyUsers = "/userapi/queryOutdatedPolicy" + QueryKeyBackupPath = "/userapi/queryKeyBackup" + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPolicyVersionPath = "/userapi/queryPolicyVersion" + QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -256,7 +257,7 @@ func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.Q span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") defer span.Finish() - apiURL := h.apiURL + QueryPolicyVersion + apiURL := h.apiURL + QueryPolicyVersionPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } @@ -264,6 +265,14 @@ func (h *httpUserInternalAPI) GetOutdatedPolicy(ctx context.Context, req *api.Qu span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") defer span.Finish() - apiURL := h.apiURL + QueryOutdatedPolicyUsers + apiURL := h.apiURL + QueryOutdatedPolicyUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + PerformUpdatePolicyVersionPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 22e49181c..70e8d17ce 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -265,7 +265,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(QueryPolicyVersion, + internalAPIMux.Handle(QueryPolicyVersionPath, httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse { request := api.QueryPolicyVersionRequest{} response := api.QueryPolicyVersionResponse{} @@ -279,7 +279,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(QueryOutdatedPolicyUsers, + internalAPIMux.Handle(QueryOutdatedPolicyUsersPath, httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse { request := api.QueryOutdatedPolicyUsersRequest{} response := api.QueryOutdatedPolicyUsersResponse{} @@ -293,4 +293,18 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformUpdatePolicyVersionPath, + httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse { + request := api.UpdatePolicyVersionRequest{} + response := api.UpdatePolicyVersionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index e172c0951..61f59c9b9 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -54,6 +54,7 @@ type Database interface { GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) + UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error // Key backups CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index db8cdecdf..1b22a44e4 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -73,6 +73,9 @@ const selectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" + "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" +const updatePolicyVersionSQL = "" + + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" + type accountsStatements struct { insertAccountStmt *sql.Stmt updatePasswordStmt *sql.Stmt @@ -82,6 +85,7 @@ type accountsStatements struct { selectNewNumericLocalpartStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, + {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, }.Prepare(db) } @@ -218,3 +223,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy( } return userIDs, rows.Err() } + +// updatePolicyVersion sets the policy_version for a specific user +func (s *accountsStatements) updatePolicyVersion( + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, +) (err error) { + stmt := s.updatePolicyVersionStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + _, err = stmt.ExecContext(ctx, policyVersion, localpart) + return err +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index ff91ae447..4e5aa20a7 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -529,7 +529,7 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli return } -// GetOutdatedPolicy queries all users which didn't accept the current policy version +// GetOutdatedPolicy queries all users which didn't accept the current policy version. func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) @@ -537,3 +537,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) }) return } + +// UpdatePolicyVersion sets the accepted policy_version for a user. +func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart) + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 95b9eb7bf..f64c05717 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -71,6 +71,9 @@ const selectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" + "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" +const updatePolicyVersionSQL = "" + + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" + type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt @@ -81,6 +84,7 @@ type accountsStatements struct { selectNewNumericLocalpartStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, + {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, }.Prepare(db) } @@ -219,3 +224,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy( } return userIDs, rows.Err() } + +// updatePolicyVersion sets the policy_version for a specific user +func (s *accountsStatements) updatePolicyVersion( + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, +) (err error) { + stmt := s.updatePolicyVersionStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + _, err = stmt.ExecContext(ctx, policyVersion, localpart) + return err +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index c7551299b..768d58f65 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -580,3 +580,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) }) return } + +// UpdatePolicyVersion sets the accepted policy_version for a user. +func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart) + }) + return +}