Add a way to update the policy_version for a user
This commit is contained in:
parent
a505471c90
commit
097f1d4609
|
@ -46,6 +46,7 @@ type UserInternalAPI interface {
|
|||
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
||||
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
||||
GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error
|
||||
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
||||
}
|
||||
|
||||
type PerformKeyBackupRequest struct {
|
||||
|
@ -347,7 +348,7 @@ type QueryPolicyVersionResponse struct {
|
|||
PolicyVersion string
|
||||
}
|
||||
|
||||
// QueryOutdatedPolicyUsersRequest is the response for QueryOutdatedPolicyUsersRequest
|
||||
// QueryOutdatedPolicyUsersRequest is the request for QueryOutdatedPolicyUsersRequest
|
||||
type QueryOutdatedPolicyUsersRequest struct {
|
||||
PolicyVersion string
|
||||
}
|
||||
|
@ -357,6 +358,14 @@ type QueryOutdatedPolicyUsersResponse struct {
|
|||
OutdatedUsers []string
|
||||
}
|
||||
|
||||
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||
type UpdatePolicyVersionRequest struct {
|
||||
PolicyVersion, LocalPart string
|
||||
}
|
||||
|
||||
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||
type UpdatePolicyVersionResponse struct{}
|
||||
|
||||
// Device represents a client's device (mobile, web, etc)
|
||||
type Device struct {
|
||||
ID string
|
||||
|
|
|
@ -127,7 +127,13 @@ func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *Quer
|
|||
|
||||
func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error {
|
||||
err := t.Impl.GetOutdatedPolicy(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||
util.GetLogger(ctx).Infof("GetOutdatedPolicy req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
|
||||
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -615,3 +615,11 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||
ctx context.Context,
|
||||
req *api.UpdatePolicyVersionRequest,
|
||||
res *api.UpdatePolicyVersionResponse,
|
||||
) error {
|
||||
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ const (
|
|||
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
||||
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
||||
PerformKeyBackupPath = "/userapi/performKeyBackup"
|
||||
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
||||
|
||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||
QueryProfilePath = "/userapi/queryProfile"
|
||||
|
@ -46,8 +47,8 @@ const (
|
|||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||
QueryPolicyVersion = "/userapi/queryPolicyVersion"
|
||||
QueryOutdatedPolicyUsers = "/userapi/queryOutdatedPolicy"
|
||||
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
||||
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
@ -256,7 +257,7 @@ func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.Q
|
|||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryPolicyVersion
|
||||
apiURL := h.apiURL + QueryPolicyVersionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
|
@ -264,6 +265,14 @@ func (h *httpUserInternalAPI) GetOutdatedPolicy(ctx context.Context, req *api.Qu
|
|||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryOutdatedPolicyUsers
|
||||
apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
|
|
@ -265,7 +265,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryPolicyVersion,
|
||||
internalAPIMux.Handle(QueryPolicyVersionPath,
|
||||
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryPolicyVersionRequest{}
|
||||
response := api.QueryPolicyVersionResponse{}
|
||||
|
@ -279,7 +279,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryOutdatedPolicyUsers,
|
||||
internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
|
||||
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryOutdatedPolicyUsersRequest{}
|
||||
response := api.QueryOutdatedPolicyUsersResponse{}
|
||||
|
@ -293,4 +293,18 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
|
||||
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
|
||||
request := api.UpdatePolicyVersionRequest{}
|
||||
response := api.UpdatePolicyVersionResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ type Database interface {
|
|||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error
|
||||
|
||||
// Key backups
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
|
|
|
@ -73,6 +73,9 @@ const selectPrivacyPolicySQL = "" +
|
|||
const batchSelectPrivacyPolicySQL = "" +
|
||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||
|
||||
const updatePolicyVersionSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
updatePasswordStmt *sql.Stmt
|
||||
|
@ -82,6 +85,7 @@ type accountsStatements struct {
|
|||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
selectPrivacyPolicyStmt *sql.Stmt
|
||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||
updatePolicyVersionStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
|
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -218,3 +223,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
|||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) updatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -529,7 +529,7 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
|
|||
return
|
||||
}
|
||||
|
||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version
|
||||
// GetOutdatedPolicy queries all users which didn't accept the current policy version.
|
||||
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
|
||||
|
@ -537,3 +537,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -71,6 +71,9 @@ const selectPrivacyPolicySQL = "" +
|
|||
const batchSelectPrivacyPolicySQL = "" +
|
||||
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
|
||||
|
||||
const updatePolicyVersionSQL = "" +
|
||||
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
@ -81,6 +84,7 @@ type accountsStatements struct {
|
|||
selectNewNumericLocalpartStmt *sql.Stmt
|
||||
selectPrivacyPolicyStmt *sql.Stmt
|
||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||
updatePolicyVersionStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
|
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
|
||||
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
|
||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -219,3 +224,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
|
|||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// updatePolicyVersion sets the policy_version for a specific user
|
||||
func (s *accountsStatements) updatePolicyVersion(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
|
||||
) (err error) {
|
||||
stmt := s.updatePolicyVersionStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -580,3 +580,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdatePolicyVersion sets the accepted policy_version for a user.
|
||||
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue