Add a way to update the policy_version for a user
This commit is contained in:
parent
a505471c90
commit
097f1d4609
|
@ -46,6 +46,7 @@ type UserInternalAPI interface {
|
||||||
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
||||||
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
||||||
GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error
|
GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error
|
||||||
|
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformKeyBackupRequest struct {
|
type PerformKeyBackupRequest struct {
|
||||||
|
@ -347,7 +348,7 @@ type QueryPolicyVersionResponse struct {
|
||||||
PolicyVersion string
|
PolicyVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryOutdatedPolicyUsersRequest is the response for QueryOutdatedPolicyUsersRequest
|
// QueryOutdatedPolicyUsersRequest is the request for QueryOutdatedPolicyUsersRequest
|
||||||
type QueryOutdatedPolicyUsersRequest struct {
|
type QueryOutdatedPolicyUsersRequest struct {
|
||||||
PolicyVersion string
|
PolicyVersion string
|
||||||
}
|
}
|
||||||
|
@ -357,6 +358,14 @@ type QueryOutdatedPolicyUsersResponse struct {
|
||||||
OutdatedUsers []string
|
OutdatedUsers []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||||
|
type UpdatePolicyVersionRequest struct {
|
||||||
|
PolicyVersion, LocalPart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||||
|
type UpdatePolicyVersionResponse struct{}
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
|
|
|
@ -127,7 +127,13 @@ func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *Quer
|
||||||
|
|
||||||
func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error {
|
func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error {
|
||||||
err := t.Impl.GetOutdatedPolicy(ctx, req, res)
|
err := t.Impl.GetOutdatedPolicy(ctx, req, res)
|
||||||
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
|
util.GetLogger(ctx).Infof("GetOutdatedPolicy req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
|
||||||
|
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -615,3 +615,11 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.UpdatePolicyVersionRequest,
|
||||||
|
res *api.UpdatePolicyVersionResponse,
|
||||||
|
) error {
|
||||||
|
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
||||||
|
}
|
||||||
|
|
|
@ -37,6 +37,7 @@ const (
|
||||||
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
||||||
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
||||||
PerformKeyBackupPath = "/userapi/performKeyBackup"
|
PerformKeyBackupPath = "/userapi/performKeyBackup"
|
||||||
|
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
||||||
|
|
||||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
|
@ -46,8 +47,8 @@ const (
|
||||||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||||
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||||
QueryPolicyVersion = "/userapi/queryPolicyVersion"
|
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
||||||
QueryOutdatedPolicyUsers = "/userapi/queryOutdatedPolicy"
|
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -256,7 +257,7 @@ func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.Q
|
||||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
apiURL := h.apiURL + QueryPolicyVersion
|
apiURL := h.apiURL + QueryPolicyVersionPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -264,6 +265,14 @@ func (h *httpUserInternalAPI) GetOutdatedPolicy(ctx context.Context, req *api.Qu
|
||||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
apiURL := h.apiURL + QueryOutdatedPolicyUsers
|
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
|
@ -265,7 +265,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
internalAPIMux.Handle(QueryPolicyVersion,
|
internalAPIMux.Handle(QueryPolicyVersionPath,
|
||||||
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryPolicyVersionRequest{}
|
request := api.QueryPolicyVersionRequest{}
|
||||||
response := api.QueryPolicyVersionResponse{}
|
response := api.QueryPolicyVersionResponse{}
|
||||||
|
@ -279,7 +279,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
internalAPIMux.Handle(QueryOutdatedPolicyUsers,
|
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
|
||||||
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryOutdatedPolicyUsersRequest{}
|
request := api.QueryOutdatedPolicyUsersRequest{}
|
||||||
response := api.QueryOutdatedPolicyUsersResponse{}
|
response := api.QueryOutdatedPolicyUsersResponse{}
|
||||||
|
@ -293,4 +293,18 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
|
||||||
|
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.UpdatePolicyVersionRequest{}
|
||||||
|
response := api.UpdatePolicyVersionResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,6 +54,7 @@ type Database interface {
|
||||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||||
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||||
|
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error
|
||||||
|
|
||||||
// Key backups
|
// Key backups
|
||||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||||
|
|
|
@ -73,6 +73,9 @@ const selectPrivacyPolicySQL = "" +
|
||||||
const batchSelectPrivacyPolicySQL = "" +
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||||
|
|
||||||
|
const updatePolicyVersionSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
updatePasswordStmt *sql.Stmt
|
updatePasswordStmt *sql.Stmt
|
||||||
|
@ -82,6 +85,7 @@ type accountsStatements struct {
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
selectPrivacyPolicyStmt *sql.Stmt
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,3 +223,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
return userIDs, rows.Err()
|
return userIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
|
func (s *accountsStatements) updatePolicyVersion(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -529,7 +529,7 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
// GetOutdatedPolicy queries all users which didn't accept the current policy version.
|
||||||
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||||
|
@ -537,3 +537,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||||
|
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -71,6 +71,9 @@ const selectPrivacyPolicySQL = "" +
|
||||||
const batchSelectPrivacyPolicySQL = "" +
|
const batchSelectPrivacyPolicySQL = "" +
|
||||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||||
|
|
||||||
|
const updatePolicyVersionSQL = "" +
|
||||||
|
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -81,6 +84,7 @@ type accountsStatements struct {
|
||||||
selectNewNumericLocalpartStmt *sql.Stmt
|
selectNewNumericLocalpartStmt *sql.Stmt
|
||||||
selectPrivacyPolicyStmt *sql.Stmt
|
selectPrivacyPolicyStmt *sql.Stmt
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,3 +224,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
||||||
}
|
}
|
||||||
return userIDs, rows.Err()
|
return userIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updatePolicyVersion sets the policy_version for a specific user
|
||||||
|
func (s *accountsStatements) updatePolicyVersion(
|
||||||
|
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updatePolicyVersionStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -580,3 +580,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||||
|
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue