Add a way to update the policy_version for a user

This commit is contained in:
Till Faelligen 2022-02-14 15:08:00 +01:00
parent a505471c90
commit 097f1d4609
10 changed files with 114 additions and 17 deletions

View file

@ -46,6 +46,7 @@ type UserInternalAPI interface {
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
} }
type PerformKeyBackupRequest struct { type PerformKeyBackupRequest struct {
@ -347,7 +348,7 @@ type QueryPolicyVersionResponse struct {
PolicyVersion string PolicyVersion string
} }
// QueryOutdatedPolicyUsersRequest is the response for QueryOutdatedPolicyUsersRequest // QueryOutdatedPolicyUsersRequest is the request for QueryOutdatedPolicyUsersRequest
type QueryOutdatedPolicyUsersRequest struct { type QueryOutdatedPolicyUsersRequest struct {
PolicyVersion string PolicyVersion string
} }
@ -357,6 +358,14 @@ type QueryOutdatedPolicyUsersResponse struct {
OutdatedUsers []string OutdatedUsers []string
} }
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
type UpdatePolicyVersionRequest struct {
PolicyVersion, LocalPart string
}
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
type UpdatePolicyVersionResponse struct{}
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string

View file

@ -127,7 +127,13 @@ func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *Quer
func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error { func (t *UserInternalAPITrace) GetOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyUsersRequest, res *QueryOutdatedPolicyUsersResponse) error {
err := t.Impl.GetOutdatedPolicy(ctx, req, res) err := t.Impl.GetOutdatedPolicy(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("GetOutdatedPolicy req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error {
err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res))
return err return err
} }

View file

@ -615,3 +615,11 @@ func (a *UserInternalAPI) GetOutdatedPolicy(
return nil return nil
} }
func (a *UserInternalAPI) PerformUpdatePolicyVersion(
ctx context.Context,
req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse,
) error {
return a.AccountDB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart)
}

View file

@ -37,6 +37,7 @@ const (
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup" PerformKeyBackupPath = "/userapi/performKeyBackup"
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
@ -46,8 +47,8 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles" QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
QueryPolicyVersion = "/userapi/queryPolicyVersion" QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
QueryOutdatedPolicyUsers = "/userapi/queryOutdatedPolicy" QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -256,7 +257,7 @@ func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.Q
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish() defer span.Finish()
apiURL := h.apiURL + QueryPolicyVersion apiURL := h.apiURL + QueryPolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
@ -264,6 +265,14 @@ func (h *httpUserInternalAPI) GetOutdatedPolicy(ctx context.Context, req *api.Qu
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish() defer span.Finish()
apiURL := h.apiURL + QueryOutdatedPolicyUsers apiURL := h.apiURL + QueryOutdatedPolicyUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish()
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }

View file

@ -265,7 +265,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryPolicyVersion, internalAPIMux.Handle(QueryPolicyVersionPath,
httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse {
request := api.QueryPolicyVersionRequest{} request := api.QueryPolicyVersionRequest{}
response := api.QueryPolicyVersionResponse{} response := api.QueryPolicyVersionResponse{}
@ -279,7 +279,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryOutdatedPolicyUsers, internalAPIMux.Handle(QueryOutdatedPolicyUsersPath,
httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryOutdatedPolicyUsersRequest{} request := api.QueryOutdatedPolicyUsersRequest{}
response := api.QueryOutdatedPolicyUsersResponse{} response := api.QueryOutdatedPolicyUsersResponse{}
@ -293,4 +293,18 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformUpdatePolicyVersionPath,
httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse {
request := api.UpdatePolicyVersionRequest{}
response := api.UpdatePolicyVersionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -54,6 +54,7 @@ type Database interface {
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) error
// Key backups // Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)

View file

@ -73,6 +73,9 @@ const selectPrivacyPolicySQL = "" +
const batchSelectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
@ -82,6 +85,7 @@ type accountsStatements struct {
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -218,3 +223,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
} }
return userIDs, rows.Err() return userIDs, rows.Err()
} }
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) {
stmt := s.updatePolicyVersionStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}

View file

@ -529,7 +529,7 @@ func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (poli
return return
} }
// GetOutdatedPolicy queries all users which didn't accept the current policy version // GetOutdatedPolicy queries all users which didn't accept the current policy version.
func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion) userIDs, err = d.accounts.batchSelectPrivacyPolicy(ctx, txn, policyVersion)
@ -537,3 +537,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
}) })
return return
} }
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
})
return
}

View file

@ -71,6 +71,9 @@ const selectPrivacyPolicySQL = "" +
const batchSelectPrivacyPolicySQL = "" + const batchSelectPrivacyPolicySQL = "" +
"SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1" "SELECT localpart FROM account_accounts WHERE policy_version IS NULL or policy_version <> $1"
const updatePolicyVersionSQL = "" +
"UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -81,6 +84,7 @@ type accountsStatements struct {
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
selectPrivacyPolicyStmt *sql.Stmt selectPrivacyPolicyStmt *sql.Stmt
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -101,6 +105,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
{&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL},
{&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL},
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -219,3 +224,15 @@ func (s *accountsStatements) batchSelectPrivacyPolicy(
} }
return userIDs, rows.Err() return userIDs, rows.Err()
} }
// updatePolicyVersion sets the policy_version for a specific user
func (s *accountsStatements) updatePolicyVersion(
ctx context.Context, txn *sql.Tx, policyVersion, localpart string,
) (err error) {
stmt := s.updatePolicyVersionStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}

View file

@ -580,3 +580,11 @@ func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string)
}) })
return return
} }
// UpdatePolicyVersion sets the accepted policy_version for a user.
func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accounts.updatePolicyVersion(ctx, txn, policyVersion, localpart)
})
return
}