Add server_notice_room_id and methods to update/get it
This commit is contained in:
parent
519ea13510
commit
699617ee4d
|
@ -57,6 +57,8 @@ type UserInternalAPI interface {
|
||||||
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
|
||||||
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
|
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
|
||||||
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformKeyBackupRequest struct {
|
type PerformKeyBackupRequest struct {
|
||||||
|
@ -348,12 +350,12 @@ type QueryOpenIDTokenResponse struct {
|
||||||
ExpiresAtMS int64
|
ExpiresAtMS int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryPolicyVersionRequest is the response for QueryPolicyVersionRequest
|
// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
|
||||||
type QueryPolicyVersionRequest struct {
|
type QueryPolicyVersionRequest struct {
|
||||||
LocalPart string
|
Localpart string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryPolicyVersionResponsestruct is the response for QueryPolicyVersionResponsestruct
|
// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
|
||||||
type QueryPolicyVersionResponse struct {
|
type QueryPolicyVersionResponse struct {
|
||||||
PolicyVersion string
|
PolicyVersion string
|
||||||
}
|
}
|
||||||
|
@ -365,18 +367,36 @@ type QueryOutdatedPolicyRequest struct {
|
||||||
|
|
||||||
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
|
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
|
||||||
type QueryOutdatedPolicyResponse struct {
|
type QueryOutdatedPolicyResponse struct {
|
||||||
OutdatedUsers []string
|
UserLocalparts []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
|
||||||
type UpdatePolicyVersionRequest struct {
|
type UpdatePolicyVersionRequest struct {
|
||||||
PolicyVersion, LocalPart string
|
PolicyVersion, Localpart string
|
||||||
ServerNoticeUpdate bool
|
ServerNoticeUpdate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
|
||||||
type UpdatePolicyVersionResponse struct{}
|
type UpdatePolicyVersionResponse struct{}
|
||||||
|
|
||||||
|
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
|
||||||
|
type QueryServerNoticeRoomRequest struct {
|
||||||
|
Localpart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
|
||||||
|
type QueryServerNoticeRoomResponse struct {
|
||||||
|
RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
|
||||||
|
type UpdateServerNoticeRoomRequest struct {
|
||||||
|
Localpart, RoomID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
|
||||||
|
type UpdateServerNoticeRoomResponse struct{}
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
|
|
|
@ -167,6 +167,18 @@ func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, r
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
|
||||||
|
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
|
||||||
|
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -769,7 +769,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
|
||||||
res *api.QueryPolicyVersionResponse,
|
res *api.QueryPolicyVersionResponse,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart)
|
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -783,7 +783,7 @@ func (a *UserInternalAPI) QueryOutdatedPolicy(
|
||||||
res *api.QueryOutdatedPolicyResponse,
|
res *api.QueryOutdatedPolicyResponse,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -796,5 +796,26 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
|
||||||
req *api.UpdatePolicyVersionRequest,
|
req *api.UpdatePolicyVersionRequest,
|
||||||
res *api.UpdatePolicyVersionResponse,
|
res *api.UpdatePolicyVersionResponse,
|
||||||
) error {
|
) error {
|
||||||
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate)
|
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryServerNoticeRoomRequest,
|
||||||
|
res *api.QueryServerNoticeRoomResponse,
|
||||||
|
) (err error) {
|
||||||
|
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.RoomID = roomID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.UpdateServerNoticeRoomRequest,
|
||||||
|
res *api.UpdateServerNoticeRoomResponse,
|
||||||
|
) (err error) {
|
||||||
|
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
|
||||||
}
|
}
|
|
@ -41,6 +41,7 @@ const (
|
||||||
PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
|
PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
|
||||||
PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
|
PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
|
||||||
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
|
||||||
|
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
|
||||||
|
|
||||||
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
QueryKeyBackupPath = "/userapi/queryKeyBackup"
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
|
@ -55,6 +56,7 @@ const (
|
||||||
QueryNotificationsPath = "/pushserver/queryNotifications"
|
QueryNotificationsPath = "/pushserver/queryNotifications"
|
||||||
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
|
||||||
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
|
||||||
|
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -337,3 +339,19 @@ func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, re
|
||||||
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
apiURL := h.apiURL + PerformUpdatePolicyVersionPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryServerNoticeRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
|
@ -389,4 +389,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryServerNoticeRoomPath,
|
||||||
|
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryServerNoticeRoomRequest{}
|
||||||
|
response := api.QueryServerNoticeRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
|
||||||
|
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.UpdateServerNoticeRoomRequest{}
|
||||||
|
response := api.UpdateServerNoticeRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,6 +55,8 @@ type Database interface {
|
||||||
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
|
||||||
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
|
||||||
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
|
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
|
||||||
|
|
||||||
// Key backups
|
// Key backups
|
||||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||||
|
|
|
@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- The policy version this user has accepted
|
-- The policy version this user has accepted
|
||||||
policy_version TEXT,
|
policy_version TEXT,
|
||||||
-- The policy version the user received from the server notices room
|
-- The policy version the user received from the server notices room
|
||||||
policy_version_sent TEXT
|
policy_version_sent TEXT,
|
||||||
|
server_notice_room_id TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
@ -85,6 +86,12 @@ const updatePolicyVersionSQL = "" +
|
||||||
const updatePolicyVersionServerNoticeSQL = "" +
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const selectServerNoticeRoomSQL = "" +
|
||||||
|
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateServerNoticeRoomSQL = "" +
|
||||||
|
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
updatePasswordStmt *sql.Stmt
|
updatePasswordStmt *sql.Stmt
|
||||||
|
@ -96,6 +103,8 @@ type accountsStatements struct {
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
updatePolicyVersionStmt *sql.Stmt
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
|
selectServerNoticeRoomStmt *sql.Stmt
|
||||||
|
updateServerNoticeRoomStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +127,8 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
|
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||||
|
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount(
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
_, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||||
if accountType != api.AccountTypeAppService {
|
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
|
||||||
} else {
|
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -251,3 +257,36 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||||
|
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
|
) (roomID string, err error) {
|
||||||
|
stmt := s.selectServerNoticeRoomStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
roomIDNull := sql.NullString{}
|
||||||
|
row := stmt.QueryRowContext(ctx, localpart)
|
||||||
|
err = row.Scan(&roomIDNull)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if roomIDNull.Valid {
|
||||||
|
return roomIDNull.String, nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||||
|
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updateServerNoticeRoomStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||||
|
return
|
||||||
|
}
|
|
@ -20,13 +20,24 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||||
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -791,3 +791,16 @@ func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, local
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID returns the server notice room, if one is set.
|
||||||
|
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
|
||||||
|
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID updates the server notice room
|
||||||
|
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
|
@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- The policy version this user has accepted
|
-- The policy version this user has accepted
|
||||||
policy_version TEXT,
|
policy_version TEXT,
|
||||||
-- The policy version the user received from the server notices room
|
-- The policy version the user received from the server notices room
|
||||||
policy_version_sent TEXT
|
policy_version_sent TEXT,
|
||||||
|
server_notice_room_id TEXT
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
@ -83,6 +84,12 @@ const updatePolicyVersionSQL = "" +
|
||||||
const updatePolicyVersionServerNoticeSQL = "" +
|
const updatePolicyVersionServerNoticeSQL = "" +
|
||||||
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
|
const selectServerNoticeRoomSQL = "" +
|
||||||
|
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateServerNoticeRoomSQL = "" +
|
||||||
|
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -95,6 +102,8 @@ type accountsStatements struct {
|
||||||
batchSelectPrivacyPolicyStmt *sql.Stmt
|
batchSelectPrivacyPolicyStmt *sql.Stmt
|
||||||
updatePolicyVersionStmt *sql.Stmt
|
updatePolicyVersionStmt *sql.Stmt
|
||||||
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
updatePolicyVersionServerNoticeStmt *sql.Stmt
|
||||||
|
selectServerNoticeRoomStmt *sql.Stmt
|
||||||
|
updateServerNoticeRoomStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +127,8 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
|
||||||
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
|
||||||
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
|
||||||
|
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
|
||||||
|
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount(
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||||
if accountType != api.AccountTypeAppService {
|
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
|
||||||
} else {
|
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -254,3 +260,36 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
||||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectServerNoticeRoomID queries the server notice room ID.
|
||||||
|
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
|
) (roomID string, err error) {
|
||||||
|
stmt := s.selectServerNoticeRoomStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
roomIDNull := sql.NullString{}
|
||||||
|
row := stmt.QueryRowContext(ctx, localpart)
|
||||||
|
err = row.Scan(&roomIDNull)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if roomIDNull.Valid {
|
||||||
|
return roomIDNull.String, nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||||
|
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.updateServerNoticeRoomStmt
|
||||||
|
if txn != nil {
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
}
|
||||||
|
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||||
|
return
|
||||||
|
}
|
|
@ -20,12 +20,23 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownAddPolicyVersion(tx *sql.Tx) error {
|
func DownAddPolicyVersion(tx *sql.Tx) error {
|
||||||
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" +
|
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
|
||||||
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,8 @@ type AccountsTable interface {
|
||||||
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
|
||||||
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
|
||||||
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
|
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
|
||||||
|
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
|
||||||
|
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
|
|
Loading…
Reference in a new issue