From 699617ee4d13690d8761c0eec0affb050595e59f Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Mon, 7 Mar 2022 09:41:25 +0100 Subject: [PATCH] Add server_notice_room_id and methods to update/get it --- userapi/api/api.go | 30 +++++++-- userapi/api/api_trace.go | 12 ++++ userapi/internal/api.go | 27 +++++++- userapi/inthttp/client.go | 66 ++++++++++++------- userapi/inthttp/server.go | 28 ++++++++ userapi/storage/interface.go | 2 + userapi/storage/postgres/accounts_table.go | 53 +++++++++++++-- .../2022021414375800_add_policy_version.go | 15 ++++- userapi/storage/shared/storage.go | 13 ++++ userapi/storage/sqlite3/accounts_table.go | 53 +++++++++++++-- .../2022021414375800_add_policy_version.go | 15 ++++- userapi/storage/tables/interface.go | 2 + 12 files changed, 266 insertions(+), 50 deletions(-) diff --git a/userapi/api/api.go b/userapi/api/api.go index 244d13bb1..96a6164e4 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -57,6 +57,8 @@ type UserInternalAPI interface { QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error + SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error) + UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error) } type PerformKeyBackupRequest struct { @@ -348,12 +350,12 @@ type QueryOpenIDTokenResponse struct { ExpiresAtMS int64 } -// QueryPolicyVersionRequest is the response for QueryPolicyVersionRequest +// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest type QueryPolicyVersionRequest struct { - LocalPart string + Localpart string } -// QueryPolicyVersionResponsestruct is the response for QueryPolicyVersionResponsestruct +// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest type QueryPolicyVersionResponse struct { PolicyVersion string } @@ -365,18 +367,36 @@ type QueryOutdatedPolicyRequest struct { // QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest type QueryOutdatedPolicyResponse struct { - OutdatedUsers []string + UserLocalparts []string } // UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest type UpdatePolicyVersionRequest struct { - PolicyVersion, LocalPart string + PolicyVersion, Localpart string ServerNoticeUpdate bool } // UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest type UpdatePolicyVersionResponse struct{} +// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest +type QueryServerNoticeRoomRequest struct { + Localpart string +} + +// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest +type QueryServerNoticeRoomResponse struct { + RoomID string +} + +// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest +type UpdateServerNoticeRoomRequest struct { + Localpart, RoomID string +} + +// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest +type UpdateServerNoticeRoomResponse struct{} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index b8bde4340..9fbdde616 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -167,6 +167,18 @@ func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, r return err } +func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error { + err := t.Impl.SelectServerNoticeRoomID(ctx, req, res) + util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error { + err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res) + util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index dafbf2180..fd1ecd459 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -769,7 +769,7 @@ func (a *UserInternalAPI) QueryPolicyVersion( res *api.QueryPolicyVersionResponse, ) error { var err error - res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart) + res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart) if err != nil { return err } @@ -783,7 +783,7 @@ func (a *UserInternalAPI) QueryOutdatedPolicy( res *api.QueryOutdatedPolicyResponse, ) error { var err error - res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion) + res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion) if err != nil { return err } @@ -796,5 +796,26 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion( req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse, ) error { - return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate) + return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate) } + +func (a *UserInternalAPI) SelectServerNoticeRoomID( + ctx context.Context, + req *api.QueryServerNoticeRoomRequest, + res *api.QueryServerNoticeRoomResponse, +) (err error) { + roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart) + if err != nil { + return err + } + res.RoomID = roomID + return nil +} + +func (a *UserInternalAPI) UpdateServerNoticeRoomID( + ctx context.Context, + req *api.UpdateServerNoticeRoomRequest, + res *api.UpdateServerNoticeRoomResponse, +) (err error) { + return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID) +} \ No newline at end of file diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 59be975cd..4609c2e1a 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -28,33 +28,35 @@ import ( const ( InputAccountDataPath = "/userapi/inputAccountData" - PerformDeviceCreationPath = "/userapi/performDeviceCreation" - PerformAccountCreationPath = "/userapi/performAccountCreation" - PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" - PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" - PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" - PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" - PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" - PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" - PerformKeyBackupPath = "/userapi/performKeyBackup" - PerformPusherSetPath = "/pushserver/performPusherSet" - PerformPusherDeletionPath = "/pushserver/performPusherDeletion" - PerformPushRulesPutPath = "/pushserver/performPushRulesPut" - PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" + PerformDeviceCreationPath = "/userapi/performDeviceCreation" + PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" + PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" + PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" + PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" + PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" + PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformPusherSetPath = "/pushserver/performPusherSet" + PerformPusherDeletionPath = "/pushserver/performPusherDeletion" + PerformPushRulesPutPath = "/pushserver/performPushRulesPut" + PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" + PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom" - QueryKeyBackupPath = "/userapi/queryKeyBackup" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" - QuerySearchProfilesPath = "/userapi/querySearchProfiles" - QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" - QueryPushersPath = "/pushserver/queryPushers" - QueryPushRulesPath = "/pushserver/queryPushRules" - QueryNotificationsPath = "/pushserver/queryNotifications" + QueryKeyBackupPath = "/userapi/queryKeyBackup" + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QuerySearchProfilesPath = "/userapi/querySearchProfiles" + QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPushersPath = "/pushserver/queryPushers" + QueryPushRulesPath = "/pushserver/queryPushRules" + QueryNotificationsPath = "/pushserver/queryNotifications" QueryPolicyVersionPath = "/userapi/queryPolicyVersion" QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy" + QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -337,3 +339,19 @@ func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, re apiURL := h.apiURL + PerformUpdatePolicyVersionPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID") + defer span.Finish() + + apiURL := h.apiURL + QueryServerNoticeRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID") + defer span.Finish() + + apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} \ No newline at end of file diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 66e2c6fd0..c81dce19a 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -389,4 +389,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryServerNoticeRoomPath, + httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse { + request := api.QueryServerNoticeRoomRequest{} + response := api.QueryServerNoticeRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.SelectServerNoticeRoomID(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath, + httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse { + request := api.UpdateServerNoticeRoomRequest{} + response := api.UpdateServerNoticeRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index aca669820..ad99a930f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -55,6 +55,8 @@ type Database interface { GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error + SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) + UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) // Key backups CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 22ef83274..097fe8998 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- The policy version this user has accepted policy_version TEXT, -- The policy version the user received from the server notices room - policy_version_sent TEXT + policy_version_sent TEXT, + server_notice_room_id TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); @@ -85,6 +86,12 @@ const updatePolicyVersionSQL = "" + const updatePolicyVersionServerNoticeSQL = "" + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" +const selectServerNoticeRoomSQL = "" + + "SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1" + +const updateServerNoticeRoomSQL = "" + + "UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2" + type accountsStatements struct { insertAccountStmt *sql.Stmt updatePasswordStmt *sql.Stmt @@ -96,6 +103,8 @@ type accountsStatements struct { batchSelectPrivacyPolicyStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt updatePolicyVersionServerNoticeStmt *sql.Stmt + selectServerNoticeRoomStmt *sql.Stmt + updateServerNoticeRoomStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -118,6 +127,8 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, + {&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL}, + {&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL}, }.Prepare(db) } @@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) - var err error - if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) - } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "") - } + _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) if err != nil { return nil, err } @@ -251,3 +257,36 @@ func (s *accountsStatements) UpdatePolicyVersion( _, err = stmt.ExecContext(ctx, policyVersion, localpart) return err } + +// SelectServerNoticeRoomID queries the server notice room ID. +func (s *accountsStatements) SelectServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart string, +) (roomID string, err error) { + stmt := s.selectServerNoticeRoomStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + + roomIDNull := sql.NullString{} + row := stmt.QueryRowContext(ctx, localpart) + err = row.Scan(&roomIDNull) + if err != nil { + return "", err + } + if roomIDNull.Valid { + return roomIDNull.String, nil + } + return "", nil +} + +// UpdateServerNoticeRoomID sets the server notice room ID. +func (s *accountsStatements) UpdateServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart, roomID string, +) (err error) { + stmt := s.updateServerNoticeRoomStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + _, err = stmt.ExecContext(ctx, roomID, localpart) + return +} \ No newline at end of file diff --git a/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go b/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go index 27347d21c..1638fb4fe 100644 --- a/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go +++ b/userapi/storage/postgres/deltas/2022021414375800_add_policy_version.go @@ -20,13 +20,24 @@ func UpAddPolicyVersion(tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } return nil } func DownAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + - "ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 262531dcb..605b90cd3 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -791,3 +791,16 @@ func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, local }) return } + +// SelectServerNoticeRoomID returns the server notice room, if one is set. +func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) { + return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart) +} + +// UpdateServerNoticeRoomID updates the server notice room +func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID) + }) + return +} \ No newline at end of file diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index cbd073d51..9f7e75226 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- The policy version this user has accepted policy_version TEXT, -- The policy version the user received from the server notices room - policy_version_sent TEXT + policy_version_sent TEXT, + server_notice_room_id TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); @@ -83,6 +84,12 @@ const updatePolicyVersionSQL = "" + const updatePolicyVersionServerNoticeSQL = "" + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" +const selectServerNoticeRoomSQL = "" + + "SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1" + +const updateServerNoticeRoomSQL = "" + + "UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2" + type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt @@ -95,6 +102,8 @@ type accountsStatements struct { batchSelectPrivacyPolicyStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt updatePolicyVersionServerNoticeStmt *sql.Stmt + selectServerNoticeRoomStmt *sql.Stmt + updateServerNoticeRoomStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -118,6 +127,8 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, + {&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL}, + {&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL}, }.Prepare(db) } @@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - var err error - if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) - } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "") - } + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion) if err != nil { return nil, err } @@ -254,3 +260,36 @@ func (s *accountsStatements) UpdatePolicyVersion( _, err = stmt.ExecContext(ctx, policyVersion, localpart) return err } + +// SelectServerNoticeRoomID queries the server notice room ID. +func (s *accountsStatements) SelectServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart string, +) (roomID string, err error) { + stmt := s.selectServerNoticeRoomStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + + roomIDNull := sql.NullString{} + row := stmt.QueryRowContext(ctx, localpart) + err = row.Scan(&roomIDNull) + if err != nil { + return "", err + } + if roomIDNull.Valid { + return roomIDNull.String, nil + } + return "", nil +} + +// UpdateServerNoticeRoomID sets the server notice room ID. +func (s *accountsStatements) UpdateServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart, roomID string, +) (err error) { + stmt := s.updateServerNoticeRoomStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + _, err = stmt.ExecContext(ctx, roomID, localpart) + return +} \ No newline at end of file diff --git a/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go b/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go index 683210ca7..251ec4e40 100644 --- a/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go +++ b/userapi/storage/sqlite3/deltas/2022021414375800_add_policy_version.go @@ -20,12 +20,23 @@ func UpAddPolicyVersion(tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } return nil } func DownAddPolicyVersion(tx *sql.Tx) error { - _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + - "ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index ef067ed07..caf41a91c 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -41,6 +41,8 @@ type AccountsTable interface { SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error) + SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error) + UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error) } type DevicesTable interface {