From 32538640db6bfbdb890729e8137151ffbbec9a28 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 27 Jul 2021 12:47:32 +0100 Subject: [PATCH 1/7] Key backups (1/2) : Add E2E session backup metadata tables (#1943) * Initial key backup paths and userapi API * Fix unit tests * Add key backup table * Glue REST API to database * Linting * use writer on sqlite --- clientapi/routing/key_backup.go | 173 ++++++++++++++++++ clientapi/routing/routing.go | 42 +++++ setup/mscs/msc2836/msc2836_test.go | 4 + setup/mscs/msc2946/msc2946_test.go | 4 + sytest-whitelist | 1 + userapi/api/api.go | 33 ++++ userapi/internal/api.go | 54 ++++++ userapi/inthttp/client.go | 23 +++ userapi/storage/accounts/interface.go | 6 + .../postgres/key_backup_version_table.go | 144 +++++++++++++++ userapi/storage/accounts/postgres/storage.go | 43 +++++ .../sqlite3/key_backup_version_table.go | 142 ++++++++++++++ userapi/storage/accounts/sqlite3/storage.go | 43 +++++ 13 files changed, 712 insertions(+) create mode 100644 clientapi/routing/key_backup.go create mode 100644 userapi/storage/accounts/postgres/key_backup_version_table.go create mode 100644 userapi/storage/accounts/sqlite3/key_backup_version_table.go diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go new file mode 100644 index 000000000..0aa736245 --- /dev/null +++ b/clientapi/routing/key_backup.go @@ -0,0 +1,173 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type keyBackupVersion struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` +} + +type keyBackupVersionCreateResponse struct { + Version string `json:"version"` +} + +type keyBackupVersionResponse struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + +// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. +// Implements POST /_matrix/client/r0/room_keys/version +func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: "", + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. +// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} +func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionResponse{ + Algorithm: queryResp.Algorithm, + AuthData: queryResp.AuthData, + Count: queryResp.Count, + ETag: queryResp.ETag, + Version: queryResp.Version, + }, + } +} + +// Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion` +// Implements PUT /_matrix/client/r0/room_keys/version/{version} +func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. +// Implements DELETE /_matrix/client/r0/room_keys/version/{version} +func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + DeleteBackup: true, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d768247a4..194ab2997 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -896,6 +896,48 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + // Key Backup Versions + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + return KeyBackupVersion(req, userAPI, device, version) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version", + httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return KeyBackupVersion(req, userAPI, device, "") + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + return ModifyKeyBackupVersionAuthData(req, userAPI, device, version) + }), + ).Methods(http.MethodPut) + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + return DeleteKeyBackupVersion(req, userAPI, device, version) + }), + ).Methods(http.MethodDelete) + r0mux.Handle("/room_keys/version", + httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return CreateKeyBackupVersion(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 79aaebc0b..1bbb485cf 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -554,6 +554,10 @@ func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.Quer func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 96160c10d..2c195d128 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -373,6 +373,10 @@ func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *usera func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { diff --git a/sytest-whitelist b/sytest-whitelist index d90ba4fb9..e2bb41da4 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -540,3 +540,4 @@ Key notary server must not overwrite a valid key with a spurious result from the GET /rooms/:room_id/aliases lists aliases Only room members can list aliases of a room Users with sufficient power-level can delete other's aliases +Can create more than 10 backup versions diff --git a/userapi/api/api.go b/userapi/api/api.go index 407350123..ff10bfcd8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -33,6 +33,8 @@ type UserInternalAPI interface { PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -42,6 +44,37 @@ type UserInternalAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } +type PerformKeyBackupRequest struct { + UserID string + Version string // optional if modifying a key backup + AuthData json.RawMessage + Algorithm string + DeleteBackup bool // if true will delete the backup based on 'Version'. +} + +type PerformKeyBackupResponse struct { + Error string // set if there was a problem performing the request + BadInput bool // if set, the Error was due to bad input (HTTP 400) + Exists bool // set to true if the Version exists + Version string +} + +type QueryKeyBackupRequest struct { + UserID string + Version string // the version to query, if blank it means the latest +} + +type QueryKeyBackupResponse struct { + Error string + Exists bool + + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + // InputAccountDataRequest is the request for InputAccountData type InputAccountDataRequest struct { UserID string // required: the user to set account data for diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 21933c1c4..9ff69298a 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -442,3 +442,57 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp return nil } + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // Delete + if req.DeleteBackup { + if req.Version == "" { + res.BadInput = true + res.Error = "must specify a version to delete" + return + } + exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to delete backup: %s", err) + } + res.Exists = exists + res.Version = req.Version + return + } + // Create + if req.Version == "" { + version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to create backup: %s", err) + } + res.Exists = err == nil + res.Version = version + return + } + // Update + err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Version = req.Version +} + +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + version, algorithm, authData, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + res.Version = version + if err != nil { + if err == sql.ErrNoRows { + res.Exists = false + return + } + res.Error = fmt.Sprintf("failed to query key backup: %s", err) + return + } + res.Algorithm = algorithm + res.AuthData = authData + res.Exists = !deleted + + // TODO: + res.Count = 0 + res.ETag = "" +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1cb5ef0a8..a89d1a266 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -36,7 +36,9 @@ const ( PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" + PerformKeyBackupPath = "/userapi/performKeyBackup" + QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" @@ -225,3 +227,24 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryOpenIDTokenPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + PerformKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} +func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + QueryKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 5aa61b909..88fdab481 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -54,6 +54,12 @@ type Database interface { DeactivateAccount(ctx context.Context, localpart string) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) + + // Key backups + CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) + DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go new file mode 100644 index 000000000..1b693e565 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -0,0 +1,144 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + var versionInt int64 + if version == "" { + err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + } else { + versionInt, err = strconv.ParseInt(version, 10, 64) + } + if err != nil { + return + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c5e74ed15..719e9878c 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -45,6 +45,7 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackups keyBackupVersionStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -93,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -368,3 +372,42 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go new file mode 100644 index 000000000..3e85705cf --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -0,0 +1,142 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version INTEGER PRIMARY KEY AUTOINCREMENT, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted INTEGER DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + var versionInt int64 + if version == "" { + err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + } else { + versionInt, err = strconv.ParseInt(version, 10, 64) + } + if err != nil { + return + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index c0f7118cb..99e016ffd 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -43,6 +43,7 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackups keyBackupVersionStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -97,6 +98,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -406,3 +410,42 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} From a060df91e206903e4e3cbf7b7d2dabddfa0bf788 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 27 Jul 2021 12:47:50 +0100 Subject: [PATCH 2/7] Use db writer on sqlite account table (#1944) --- userapi/storage/accounts/sqlite3/storage.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 99e016ffd..b10f25ade 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -160,8 +160,9 @@ func (d *Database) SetPassword( if err != nil { return err } - err = d.accounts.updatePassword(ctx, localpart, hash) - return err + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.updatePassword(ctx, localpart, hash) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -388,7 +389,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.deactivateAccount(ctx, localpart) + }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect From b3754d68fcbe9022eb0bf4f8eda7102b7c27e62d Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 27 Jul 2021 17:08:53 +0100 Subject: [PATCH 3/7] Key Backups (2/3) : Add E2E backup key tables (#1945) * Add PUT key backup endpoints and glue them to PerformKeyBackup * Add tables for storing backup keys and glue them into the user API * Don't create tables whilst still WIPing * writer on sqlite please * Linting --- clientapi/routing/key_backup.go | 45 ++++++ clientapi/routing/routing.go | 79 +++++++++++ sytest-whitelist | 1 - userapi/api/api.go | 30 +++- userapi/internal/api.go | 57 ++++++-- userapi/storage/accounts/interface.go | 3 +- .../accounts/postgres/key_backup_table.go | 134 ++++++++++++++++++ .../postgres/key_backup_version_table.go | 34 ++++- userapi/storage/accounts/postgres/storage.go | 115 +++++++++++++-- .../accounts/sqlite3/key_backup_table.go | 134 ++++++++++++++++++ .../sqlite3/key_backup_version_table.go | 34 ++++- userapi/storage/accounts/sqlite3/storage.go | 116 +++++++++++++-- 12 files changed, 737 insertions(+), 45 deletions(-) create mode 100644 userapi/storage/accounts/postgres/key_backup_table.go create mode 100644 userapi/storage/accounts/sqlite3/key_backup_table.go diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 0aa736245..dd21d482e 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -42,6 +42,17 @@ type keyBackupVersionResponse struct { Version string `json:"version"` } +type keyBackupSessionRequest struct { + Rooms map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + } `json:"rooms"` +} + +type keyBackupSessionResponse struct { + Count int64 `json:"count"` + ETag string `json:"etag"` +} + // Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. // Implements POST /_matrix/client/r0/room_keys/version func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { @@ -171,3 +182,37 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, }, } } + +// Upload a bunch of session keys for a given `version`. +func UploadBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, +) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + Keys: *keys, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupSessionResponse{ + Count: performKeyBackupResp.KeyCount, + ETag: performKeyBackupResp.KeyETag, + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 194ab2997..3e0c53ee4 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -938,6 +938,85 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + // E2E Backup Keys + // Bulk room and session + r0mux.Handle("/room_keys/keys", + httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody keyBackupSessionRequest + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }), + ).Methods(http.MethodPut) + // Single room bulk session + r0mux.Handle("/room_keys/keys/{roomID}", + httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + roomID := vars["roomID"] + var reqBody keyBackupSessionRequest + reqBody.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: map[string]userapi.KeyBackupSession{}, + } + body := reqBody.Rooms[roomID] + resErr := clientutil.UnmarshalJSONRequest(req, &body) + if resErr != nil { + return *resErr + } + reqBody.Rooms[roomID] = body + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }), + ).Methods(http.MethodPut) + // Single room, single session + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", + httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody userapi.KeyBackupSession + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + roomID := vars["roomID"] + sessionID := vars["sessionID"] + var keyReq keyBackupSessionRequest + keyReq.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{} + keyReq.Rooms[roomID].Sessions[sessionID] = reqBody + return UploadBackupKeys(req, userAPI, device, version, &keyReq) + }), + ).Methods(http.MethodPut) + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/sytest-whitelist b/sytest-whitelist index e2bb41da4..d90ba4fb9 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -540,4 +540,3 @@ Key notary server must not overwrite a valid key with a spurious result from the GET /rooms/:room_id/aliases lists aliases Only room members can list aliases of a room Users with sufficient power-level can delete other's aliases -Can create more than 10 backup versions diff --git a/userapi/api/api.go b/userapi/api/api.go index ff10bfcd8..7e18d72f9 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,13 +50,39 @@ type PerformKeyBackupRequest struct { AuthData json.RawMessage Algorithm string DeleteBackup bool // if true will delete the backup based on 'Version'. + + // The keys to upload, if any. If blank, creates/updates/deletes key version metadata only. + Keys struct { + Rooms map[string]struct { + Sessions map[string]KeyBackupSession `json:"sessions"` + } `json:"rooms"` + } +} + +// KeyBackupData in https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0room_keyskeysroomidsessionid +type KeyBackupSession struct { + FirstMessageIndex int `json:"first_message_index"` + ForwardedCount int `json:"forwarded_count"` + IsVerified bool `json:"is_verified"` + SessionData json.RawMessage `json:"session_data"` +} + +// Internal KeyBackupData for passing to/from the storage layer +type InternalKeyBackupSession struct { + KeyBackupSession + RoomID string + SessionID string } type PerformKeyBackupResponse struct { Error string // set if there was a problem performing the request BadInput bool // if set, the Error was due to bad input (HTTP 400) - Exists bool // set to true if the Version exists - Version string + + Exists bool // set to true if the Version exists + Version string // the newly created version + + KeyCount int64 // only set if Keys were given in the request + KeyETag string // only set if Keys were given in the request } type QueryKeyBackupRequest struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9ff69298a..27e179636 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -444,7 +444,7 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp } func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { - // Delete + // Delete metadata if req.DeleteBackup { if req.Version == "" { res.BadInput = true @@ -459,7 +459,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform res.Version = req.Version return } - // Create + // Create metadata if req.Version == "" { version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { @@ -469,16 +469,55 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform res.Version = version return } - // Update - err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to update backup: %s", err) + // Update metadata + if len(req.Keys.Rooms) == 0 { + err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Version = req.Version + return } - res.Version = req.Version + // Upload Keys for a specific version metadata + a.uploadBackupKeys(ctx, req, res) +} + +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // ensure the version metadata exists + version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to query version: %s", err) + return + } + if deleted { + res.Error = "backup was deleted" + return + } + res.Exists = true + res.Version = version + + // map keys to a form we can upload more easily - the map ensures we have no duplicates. + var uploads []api.InternalKeyBackupSession + for roomID, data := range req.Keys.Rooms { + for sessionID, sessionData := range data.Sessions { + uploads = append(uploads, api.InternalKeyBackupSession{ + RoomID: roomID, + SessionID: sessionID, + KeyBackupSession: sessionData, + }) + } + } + count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + if err != nil { + res.Error = fmt.Sprintf("failed to upsert keys: %s", err) + return + } + res.KeyCount = count + res.KeyETag = etag } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - version, algorithm, authData, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { @@ -494,5 +533,5 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB // TODO: res.Count = 0 - res.ETag = "" + res.ETag = etag } diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 88fdab481..4fd9c1772 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -59,7 +59,8 @@ type Database interface { CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) - GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) + UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go new file mode 100644 index 000000000..0dc5879b6 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -0,0 +1,134 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt +} + +// nolint:unused +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { + return + } + if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { + return + } + if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { + return + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return + } + return +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 1b693e565..323a842df 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -33,6 +33,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, + etag TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); @@ -40,16 +41,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_ ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" -const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? +const updateKeyBackupAuthDataSQL = "" + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + const deleteKeyBackupSQL = "" + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" @@ -60,8 +64,10 @@ type keyBackupVersionStatements struct { deleteKeyBackupStmt *sql.Stmt selectKeyBackupStmt *sql.Stmt selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt } +// nolint:unused func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupVersionTableSchema) if err != nil { @@ -82,14 +88,17 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { return } + if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { + return + } return } func (s *keyBackupVersionStatements) insertKeyBackup( - ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 - err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) return strconv.FormatInt(versionInt, 10), err } @@ -104,6 +113,17 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + func (s *keyBackupVersionStatements) deleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { @@ -124,7 +144,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) selectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) @@ -137,7 +157,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup( versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int var authDataStr string - err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) deleted = deletedInt == 1 authData = json.RawMessage(authDataStr) return diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 719e9878c..b07218b29 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "time" @@ -45,7 +46,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements - keyBackups keyBackupVersionStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -94,9 +96,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } + /* + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } */ return d, nil } @@ -377,7 +383,7 @@ func (d *Database) CreateKeyBackup( ctx context.Context, userID, algorithm string, authData json.RawMessage, ) (version string, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") return err }) return @@ -387,7 +393,7 @@ func (d *Database) UpdateKeyBackupAuthData( ctx context.Context, userID, version string, authData json.RawMessage, ) (err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) }) return } @@ -396,7 +402,7 @@ func (d *Database) DeleteKeyBackup( ctx context.Context, userID, version string, ) (exists bool, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) return err }) return @@ -404,10 +410,101 @@ func (d *Database) DeleteKeyBackup( func (d *Database) GetKeyBackup( ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) return err }) return } + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + return nil + }) + return +} + +// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't +// create circular import loops +func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if uploaded.IsVerified && !existing.IsVerified { + return true + } + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + if uploaded.FirstMessageIndex < existing.FirstMessageIndex { + return true + } + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + if uploaded.ForwardedCount < existing.ForwardedCount { + return true + } + return false +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go new file mode 100644 index 000000000..268bda936 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -0,0 +1,134 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt +} + +// nolint:unused +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { + return + } + if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { + return + } + if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { + return + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return + } + return +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index 3e85705cf..72e9b132e 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -31,6 +31,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( version INTEGER PRIMARY KEY AUTOINCREMENT, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, + etag TEXT NOT NULL, deleted INTEGER DEFAULT 0 NOT NULL ); @@ -38,16 +39,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_ ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" -const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? +const updateKeyBackupAuthDataSQL = "" + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + const deleteKeyBackupSQL = "" + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" @@ -58,8 +62,10 @@ type keyBackupVersionStatements struct { deleteKeyBackupStmt *sql.Stmt selectKeyBackupStmt *sql.Stmt selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt } +// nolint:unused func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupVersionTableSchema) if err != nil { @@ -80,14 +86,17 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { return } + if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { + return + } return } func (s *keyBackupVersionStatements) insertKeyBackup( - ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 - err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) return strconv.FormatInt(versionInt, 10), err } @@ -102,6 +111,17 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + func (s *keyBackupVersionStatements) deleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { @@ -122,7 +142,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) selectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) @@ -135,7 +155,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup( versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int var authDataStr string - err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) deleted = deletedInt == 1 authData = json.RawMessage(authDataStr) return diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index b10f25ade..4fae621fa 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "sync" "time" @@ -43,7 +44,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements - keyBackups keyBackupVersionStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -98,9 +100,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } + /* + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } */ return d, nil } @@ -418,7 +424,7 @@ func (d *Database) CreateKeyBackup( ctx context.Context, userID, algorithm string, authData json.RawMessage, ) (version string, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") return err }) return @@ -428,7 +434,7 @@ func (d *Database) UpdateKeyBackupAuthData( ctx context.Context, userID, version string, authData json.RawMessage, ) (err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) }) return } @@ -437,7 +443,7 @@ func (d *Database) DeleteKeyBackup( ctx context.Context, userID, version string, ) (exists bool, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) return err }) return @@ -445,10 +451,102 @@ func (d *Database) DeleteKeyBackup( func (d *Database) GetKeyBackup( ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) return err }) return } + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + + return nil + }) + return +} + +// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't +// create circular import loops +func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if uploaded.IsVerified && !existing.IsVerified { + return true + } + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + if uploaded.FirstMessageIndex < existing.FirstMessageIndex { + return true + } + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + if uploaded.ForwardedCount < existing.ForwardedCount { + return true + } + return false +} From 32bf14a37c79a02fc5e76a7071d7494bb362be53 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 27 Jul 2021 19:29:32 +0100 Subject: [PATCH 4/7] Key Backups (3/3) : Implement querying keys and various bugfixes (#1946) * Add querying device keys Makes a bunch of sytests pass * Apparently only the current version supports uploading keys * Linting --- clientapi/routing/key_backup.go | 77 +++++++++++++++++- clientapi/routing/routing.go | 81 +++++++++++++------ sytest-whitelist | 10 +++ userapi/api/api.go | 25 +++++- userapi/internal/api.go | 28 +++++-- userapi/storage/accounts/interface.go | 2 + .../accounts/postgres/key_backup_table.go | 54 +++++++++++-- .../postgres/key_backup_version_table.go | 1 - userapi/storage/accounts/postgres/storage.go | 65 ++++++++------- .../accounts/sqlite3/key_backup_table.go | 54 +++++++++++-- .../sqlite3/key_backup_version_table.go | 1 - userapi/storage/accounts/sqlite3/storage.go | 65 ++++++++------- 12 files changed, 362 insertions(+), 101 deletions(-) diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index dd21d482e..ce62a047a 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -37,7 +37,7 @@ type keyBackupVersionCreateResponse struct { type keyBackupVersionResponse struct { Algorithm string `json:"algorithm"` AuthData json.RawMessage `json:"auth_data"` - Count int `json:"count"` + Count int64 `json:"count"` ETag string `json:"etag"` Version string `json:"version"` } @@ -89,7 +89,10 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse - userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp) + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + }, &queryResp) if queryResp.Error != "" { return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) } @@ -216,3 +219,73 @@ func UploadBackupKeys( }, } } + +// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set. +func GetBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string, +) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + ReturnKeys: true, + KeysForRoomID: roomID, + KeysForSessionID: sessionID, + }, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + if sessionID != "" { + // return the key itself if it was found + roomData, ok := queryResp.Keys[roomID] + if ok { + key, ok := roomData[sessionID] + if ok { + return util.JSONResponse{ + Code: 200, + JSON: key, + } + } + } + } else if roomID != "" { + roomData, ok := queryResp.Keys[roomID] + if ok { + // wrap response in "sessions" + return util.JSONResponse{ + Code: 200, + JSON: struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + }, + } + } + } else { + // response is the same as the upload request + var resp keyBackupSessionRequest + resp.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + for roomID, roomData := range queryResp.Keys { + resp.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + } + } + return util.JSONResponse{ + Code: 200, + JSON: resp, + } + } + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("keys not found"), + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 3e0c53ee4..cfc6c47fb 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -896,11 +896,15 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - // Key Backup Versions - r0mux.Handle("/room_keys/version/{versionID}", + // Key Backup Versions (Metadata) + + r0mux.Handle("/room_keys/version/{version}", httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - version := req.URL.Query().Get("version") - return KeyBackupVersion(req, userAPI, device, version) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return KeyBackupVersion(req, userAPI, device, vars["version"]) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/room_keys/version", @@ -908,28 +912,22 @@ func Setup( return KeyBackupVersion(req, userAPI, device, "") }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version/{versionID}", + r0mux.Handle("/room_keys/version/{version}", httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - version := req.URL.Query().Get("version") - if version == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), - } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) } - return ModifyKeyBackupVersionAuthData(req, userAPI, device, version) + return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) }), ).Methods(http.MethodPut) - r0mux.Handle("/room_keys/version/{versionID}", + r0mux.Handle("/room_keys/version/{version}", httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - version := req.URL.Query().Get("version") - if version == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), - } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) } - return DeleteKeyBackupVersion(req, userAPI, device, version) + return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) }), ).Methods(http.MethodDelete) r0mux.Handle("/room_keys/version", @@ -938,7 +936,8 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // E2E Backup Keys + // Inserting E2E Backup Keys + // Bulk room and session r0mux.Handle("/room_keys/keys", httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -973,6 +972,9 @@ func Setup( } roomID := vars["roomID"] var reqBody keyBackupSessionRequest + reqBody.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) reqBody.Rooms[roomID] = struct { Sessions map[string]userapi.KeyBackupSession `json:"sessions"` }{ @@ -989,7 +991,7 @@ func Setup( ).Methods(http.MethodPut) // Single room, single session r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", - httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1009,14 +1011,47 @@ func Setup( roomID := vars["roomID"] sessionID := vars["sessionID"] var keyReq keyBackupSessionRequest + keyReq.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) keyReq.Rooms[roomID] = struct { Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }{} + }{ + Sessions: make(map[string]userapi.KeyBackupSession), + } keyReq.Rooms[roomID].Sessions[sessionID] = reqBody return UploadBackupKeys(req, userAPI, device, version, &keyReq) }), ).Methods(http.MethodPut) + // Querying E2E Backup Keys + + r0mux.Handle("/room_keys/keys", + httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}", + httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", + httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + // Deleting E2E Backup Keys + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/sytest-whitelist b/sytest-whitelist index d90ba4fb9..6bb3d770a 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -540,3 +540,13 @@ Key notary server must not overwrite a valid key with a spurious result from the GET /rooms/:room_id/aliases lists aliases Only room members can list aliases of a room Users with sufficient power-level can delete other's aliases +Can create backup version +Can update backup version +Responds correctly when backup is empty +Can backup keys +Can update keys with better versions +Will not update keys with worse versions +Will not back up to an old backup version +Can create more than 10 backup versions +Can delete backup +Deleted & recreated backups are empty diff --git a/userapi/api/api.go b/userapi/api/api.go index 7e18d72f9..b0d918560 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -67,6 +67,23 @@ type KeyBackupSession struct { SessionData json.RawMessage `json:"session_data"` } +func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if newKey.IsVerified && !a.IsVerified { + return true + } + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + if newKey.FirstMessageIndex < a.FirstMessageIndex { + return true + } + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + if newKey.ForwardedCount < a.ForwardedCount { + return true + } + return false +} + // Internal KeyBackupData for passing to/from the storage layer type InternalKeyBackupSession struct { KeyBackupSession @@ -88,6 +105,10 @@ type PerformKeyBackupResponse struct { type QueryKeyBackupRequest struct { UserID string Version string // the version to query, if blank it means the latest + + ReturnKeys bool // whether to return keys in the backup response or just the metadata + KeysForRoomID string // optional string to return keys which belong to this room + KeysForSessionID string // optional string to return keys which belong to this (room, session) } type QueryKeyBackupResponse struct { @@ -96,9 +117,11 @@ type QueryKeyBackupResponse struct { Algorithm string `json:"algorithm"` AuthData json.RawMessage `json:"auth_data"` - Count int `json:"count"` + Count int64 `json:"count"` ETag string `json:"etag"` Version string `json:"version"` + + Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true } // InputAccountDataRequest is the request for InputAccountData diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 27e179636..a2bc8ecf5 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -475,6 +475,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } + res.Exists = err == nil res.Version = req.Version return } @@ -483,8 +484,8 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { - // ensure the version metadata exists - version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + // you can only upload keys for the CURRENT version + version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return @@ -493,6 +494,11 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform res.Error = "backup was deleted" return } + if version != req.Version { + res.BadInput = true + res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) + return + } res.Exists = true res.Version = version @@ -529,9 +535,21 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB } res.Algorithm = algorithm res.AuthData = authData + res.ETag = etag res.Exists = !deleted - // TODO: - res.Count = 0 - res.ETag = etag + if !req.ReturnKeys { + res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + if err != nil { + res.Error = fmt.Sprintf("failed to count keys: %s", err) + } + return + } + + result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + if err != nil { + res.Error = fmt.Sprintf("failed to query keys: %s", err) + return + } + res.Keys = result } diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 4fd9c1772..887f7193b 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -61,6 +61,8 @@ type Database interface { DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) + GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) + CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go index 0dc5879b6..ec6518267 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); ` const insertBackupKeySQL = "" + @@ -53,14 +54,23 @@ const selectKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "WHERE user_id = $1 AND version = $2" +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + type keyBackupStatements struct { - insertBackupKeyStmt *sql.Stmt - updateBackupKeyStmt *sql.Stmt - countKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -// nolint:unused func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupTableSchema) if err != nil { @@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { return } + if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { + return + } + if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { + return + } return } @@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) selectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { - result := make(map[string]map[string]api.KeyBackupSession) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) if err != nil { return nil, err } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") for rows.Next() { var key api.InternalKeyBackupSession diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 323a842df..aca575df2 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -67,7 +67,6 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -// nolint:unused func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupVersionTableSchema) if err != nil { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b07218b29..9d6fd13a1 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - /* - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } */ + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup( return } +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + // nolint:nakedret func (d *Database) UpsertBackupKeys( ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, @@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys( if existingRoom != nil { existingSession, ok := existingRoom[newKey.SessionID] if ok { - if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { @@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys( }) return } - -// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't -// create circular import loops -func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { - // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 - // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" - if uploaded.IsVerified && !existing.IsVerified { - return true - } - // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" - if uploaded.FirstMessageIndex < existing.FirstMessageIndex { - return true - } - // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" - if uploaded.ForwardedCount < existing.ForwardedCount { - return true - } - return false -} diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go index 268bda936..c1a698e66 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); ` const insertBackupKeySQL = "" + @@ -53,14 +54,23 @@ const selectKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "WHERE user_id = $1 AND version = $2" +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + type keyBackupStatements struct { - insertBackupKeyStmt *sql.Stmt - updateBackupKeyStmt *sql.Stmt - countKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -// nolint:unused func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupTableSchema) if err != nil { @@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { return } + if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { + return + } + if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { + return + } return } @@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) selectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { - result := make(map[string]map[string]api.KeyBackupSession) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) if err != nil { return nil, err } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") for rows.Next() { var key api.InternalKeyBackupSession diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index 72e9b132e..9a58fee75 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -65,7 +65,6 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -// nolint:unused func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupVersionTableSchema) if err != nil { diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4fae621fa..728ae9011 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - /* - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } */ + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup( return } +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + // nolint:nakedret func (d *Database) UpsertBackupKeys( ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, @@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys( if existingRoom != nil { existingSession, ok := existingRoom[newKey.SessionID] if ok { - if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { @@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys( }) return } - -// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't -// create circular import loops -func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { - // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 - // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" - if uploaded.IsVerified && !existing.IsVerified { - return true - } - // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" - if uploaded.FirstMessageIndex < existing.FirstMessageIndex { - return true - } - // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" - if uploaded.ForwardedCount < existing.ForwardedCount { - return true - } - return false -} From 3e01a88a0cb5b22dfa5247e855ba2f5f7c97d8c2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Jul 2021 21:34:40 +0100 Subject: [PATCH 5/7] Update to neilalexander/utp@54ae7b1 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 30e3c97bd..fbad10d6e 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/morikuni/aec v1.0.0 // indirect - github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 + github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 diff --git a/go.sum b/go.sum index a090dbc33..ff6090a56 100644 --- a/go.sum +++ b/go.sum @@ -1184,8 +1184,8 @@ github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 h1:txJOiDxsypF8RbzbcyOD3ovip+uUWNZE/Zo7qLdARZQ= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 h1:evlcQnJY+v8XRRchV3hXzpHDl6GcEZeLXAhlH9Csdww= From 9e4618000e0347741eac1279bf6c94c3b9980785 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 28 Jul 2021 10:25:45 +0100 Subject: [PATCH 6/7] Alias key backup endpoints onto /unstable, fix key backup bugs (#1947) * Default /unstable requests to stable endpoints if not overridden specifically with a custom route * Rewrite URL * Try something different * Fix routing manually * Fix selectLatestVersionSQL * Don't return 0 if no backup version exists * Log more useful error * fix up replace keys check * Don't enforce uniqueness on e2e_room_keys_versions_idx Co-authored-by: kegsay --- clientapi/routing/routing.go | 292 +++++++++--------- userapi/api/api.go | 10 +- .../accounts/postgres/key_backup_table.go | 2 +- .../postgres/key_backup_version_table.go | 17 +- userapi/storage/accounts/postgres/storage.go | 4 +- .../accounts/sqlite3/key_backup_table.go | 2 +- .../sqlite3/key_backup_version_table.go | 17 +- userapi/storage/accounts/sqlite3/storage.go | 4 +- 8 files changed, 187 insertions(+), 161 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index cfc6c47fb..541c76282 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -898,157 +898,171 @@ func Setup( // Key Backup Versions (Metadata) - r0mux.Handle("/room_keys/version/{version}", - httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return KeyBackupVersion(req, userAPI, device, vars["version"]) - }), - ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version", - httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return KeyBackupVersion(req, userAPI, device, "") - }), - ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version/{version}", - httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) - }), - ).Methods(http.MethodPut) - r0mux.Handle("/room_keys/version/{version}", - httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) - }), - ).Methods(http.MethodDelete) - r0mux.Handle("/room_keys/version", - httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return CreateKeyBackupVersion(req, userAPI, device) - }), - ).Methods(http.MethodPost, http.MethodOptions) + getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return KeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return KeyBackupVersion(req, userAPI, device, "") + }) + + putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) + }) + + deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return CreateKeyBackupVersion(req, userAPI, device) + }) + + r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + + unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + unstableMux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) // Inserting E2E Backup Keys // Bulk room and session - r0mux.Handle("/room_keys/keys", - httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - version := req.URL.Query().Get("version") - if version == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), - } + putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), } - var reqBody keyBackupSessionRequest - resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) - if resErr != nil { - return *resErr - } - return UploadBackupKeys(req, userAPI, device, version, &reqBody) - }), - ).Methods(http.MethodPut) + } + var reqBody keyBackupSessionRequest + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + // Single room bulk session - r0mux.Handle("/room_keys/keys/{roomID}", - httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) + putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), } - version := req.URL.Query().Get("version") - if version == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), - } - } - roomID := vars["roomID"] - var reqBody keyBackupSessionRequest - reqBody.Rooms = make(map[string]struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }) - reqBody.Rooms[roomID] = struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }{ - Sessions: map[string]userapi.KeyBackupSession{}, - } - body := reqBody.Rooms[roomID] - resErr := clientutil.UnmarshalJSONRequest(req, &body) - if resErr != nil { - return *resErr - } - reqBody.Rooms[roomID] = body - return UploadBackupKeys(req, userAPI, device, version, &reqBody) - }), - ).Methods(http.MethodPut) + } + roomID := vars["roomID"] + var reqBody keyBackupSessionRequest + reqBody.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + reqBody.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: map[string]userapi.KeyBackupSession{}, + } + body := reqBody.Rooms[roomID] + resErr := clientutil.UnmarshalJSONRequest(req, &body) + if resErr != nil { + return *resErr + } + reqBody.Rooms[roomID] = body + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + // Single room, single session - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", - httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) + putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), } - version := req.URL.Query().Get("version") - if version == "" { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), - } - } - var reqBody userapi.KeyBackupSession - resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) - if resErr != nil { - return *resErr - } - roomID := vars["roomID"] - sessionID := vars["sessionID"] - var keyReq keyBackupSessionRequest - keyReq.Rooms = make(map[string]struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }) - keyReq.Rooms[roomID] = struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }{ - Sessions: make(map[string]userapi.KeyBackupSession), - } - keyReq.Rooms[roomID].Sessions[sessionID] = reqBody - return UploadBackupKeys(req, userAPI, device, version, &keyReq) - }), - ).Methods(http.MethodPut) + } + var reqBody userapi.KeyBackupSession + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + roomID := vars["roomID"] + sessionID := vars["sessionID"] + var keyReq keyBackupSessionRequest + keyReq.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + keyReq.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: make(map[string]userapi.KeyBackupSession), + } + keyReq.Rooms[roomID].Sessions[sessionID] = reqBody + return UploadBackupKeys(req, userAPI, device, version, &keyReq) + }) + + r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + + unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) // Querying E2E Backup Keys - r0mux.Handle("/room_keys/keys", - httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") - }), - ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}", - httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") - }), - ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", - httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) - }), - ).Methods(http.MethodGet, http.MethodOptions) + getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") + }) + + getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") + }) + + getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) + }) + + r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + + unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) // Deleting E2E Backup Keys diff --git a/userapi/api/api.go b/userapi/api/api.go index b0d918560..75d06dd69 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -72,13 +72,11 @@ func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool { // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" if newKey.IsVerified && !a.IsVerified { return true - } - // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" - if newKey.FirstMessageIndex < a.FirstMessageIndex { + } else if newKey.FirstMessageIndex < a.FirstMessageIndex { + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" return true - } - // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" - if newKey.ForwardedCount < a.ForwardedCount { + } else if newKey.ForwardedCount < a.ForwardedCount { + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" return true } return false diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go index ec6518267..0a2a26550 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -36,7 +36,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( session_data TEXT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); ` const insertBackupKeySQL = "" + diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index aca575df2..51a462b32 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -146,12 +146,19 @@ func (s *keyBackupVersionStatements) selectKeyBackup( ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { - err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v } else { - versionInt, err = strconv.ParseInt(version, 10, 64) - } - if err != nil { - return + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } } versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 9d6fd13a1..6bddbfc3f 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -479,7 +479,7 @@ func (d *Database) UpsertBackupKeys( err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return err + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) } } // if we shouldn't replace the key we do nothing with it @@ -490,7 +490,7 @@ func (d *Database) UpsertBackupKeys( err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return err + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) } } diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go index c1a698e66..675093514 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -36,7 +36,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( session_data TEXT NOT NULL ); CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); ` const insertBackupKeySQL = "" + diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index 9a58fee75..a9e7bf5db 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -144,12 +144,19 @@ func (s *keyBackupVersionStatements) selectKeyBackup( ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { - err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v } else { - versionInt, err = strconv.ParseInt(version, 10, 64) - } - if err != nil { - return + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } } versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 728ae9011..d752e3db8 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -520,7 +520,7 @@ func (d *Database) UpsertBackupKeys( err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return err + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) } } // if we shouldn't replace the key we do nothing with it @@ -531,7 +531,7 @@ func (d *Database) UpsertBackupKeys( err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { - return err + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) } } From ed4097825bc65f2332bcdc975ed201841221ff7c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 28 Jul 2021 18:30:04 +0100 Subject: [PATCH 7/7] Factor out StatementList to `sqlutil` and use it in `userapi` It helps with the boilerplate. --- internal/sqlutil/sql.go | 16 ++++++++++ .../storage/postgres/event_json_table.go | 4 +-- .../postgres/event_state_keys_table.go | 3 +- .../storage/postgres/event_types_table.go | 3 +- roomserver/storage/postgres/events_table.go | 3 +- roomserver/storage/postgres/invite_table.go | 3 +- .../storage/postgres/membership_table.go | 3 +- .../storage/postgres/previous_events_table.go | 3 +- .../storage/postgres/published_table.go | 3 +- .../storage/postgres/redactions_table.go | 3 +- .../storage/postgres/room_aliases_table.go | 3 +- roomserver/storage/postgres/rooms_table.go | 3 +- .../storage/postgres/state_block_table.go | 4 +-- .../storage/postgres/state_snapshot_table.go | 3 +- .../storage/postgres/transactions_table.go | 4 +-- .../storage/sqlite3/event_json_table.go | 3 +- .../storage/sqlite3/event_state_keys_table.go | 3 +- .../storage/sqlite3/event_types_table.go | 3 +- roomserver/storage/sqlite3/events_table.go | 3 +- roomserver/storage/sqlite3/invite_table.go | 3 +- .../storage/sqlite3/membership_table.go | 3 +- .../storage/sqlite3/previous_events_table.go | 3 +- roomserver/storage/sqlite3/published_table.go | 3 +- .../storage/sqlite3/redactions_table.go | 3 +- .../storage/sqlite3/room_aliases_table.go | 3 +- roomserver/storage/sqlite3/rooms_table.go | 3 +- .../storage/sqlite3/state_block_table.go | 3 +- .../storage/sqlite3/state_snapshot_table.go | 3 +- .../storage/sqlite3/transactions_table.go | 3 +- .../accounts/postgres/account_data_table.go | 15 ++++------ .../accounts/postgres/accounts_table.go | 27 +++++------------ .../accounts/postgres/key_backup_table.go | 28 ++++++------------ .../postgres/key_backup_version_table.go | 29 +++++++------------ .../storage/accounts/postgres/openid_table.go | 11 +++---- .../accounts/postgres/profile_table.go | 23 +++++---------- .../accounts/postgres/threepid_table.go | 20 ++++--------- .../accounts/sqlite3/account_data_table.go | 15 ++++------ .../accounts/sqlite3/accounts_table.go | 27 +++++------------ .../accounts/sqlite3/key_backup_table.go | 28 ++++++------------ .../sqlite3/key_backup_version_table.go | 29 +++++++------------ .../storage/accounts/sqlite3/openid_table.go | 11 +++---- .../storage/accounts/sqlite3/profile_table.go | 23 +++++---------- .../accounts/sqlite3/threepid_table.go | 20 ++++--------- 43 files changed, 145 insertions(+), 264 deletions(-) diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 6cf17bea0..98cc396ad 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -152,3 +152,19 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide } return nil } + +// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type StatementList []struct { + Statement **sql.Stmt + SQL string +} + +// Prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s StatementList) Prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + return + } + } + return +} diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index e0976b12c..32e457821 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -20,7 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -67,7 +67,7 @@ func createEventJSONTable(db *sql.DB) error { func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 616823561..3a7cf03e3 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -85,7 +84,7 @@ func createEventStateKeysTable(db *sql.DB) error { func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index f4257850a..e558072a5 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -108,7 +107,7 @@ func createEventTypesTable(db *sql.DB) error { func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 88c82083c..c549fb650 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -24,7 +24,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -160,7 +159,7 @@ func createEventsTable(db *sql.DB) error { func prepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 0a2183e27..344302c8f 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -90,7 +89,7 @@ func createInvitesTable(db *sql.DB) error { func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b4a27900c..b0d906c80 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -168,7 +167,7 @@ func createMembershipTable(db *sql.DB) error { func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 4a93c3d65..bd4e853eb 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -73,7 +72,7 @@ func createPrevEventsTable(db *sql.DB) error { func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index c180576e3..8deb68441 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -58,7 +57,7 @@ func createPublishedTable(db *sql.DB) error { func preparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, diff --git a/roomserver/storage/postgres/redactions_table.go b/roomserver/storage/postgres/redactions_table.go index 3741d5f67..5614f2bd8 100644 --- a/roomserver/storage/postgres/redactions_table.go +++ b/roomserver/storage/postgres/redactions_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -68,7 +67,7 @@ func createRedactionsTable(db *sql.DB) error { func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index c808813ee..031825fee 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -70,7 +69,7 @@ func createRoomAliasesTable(db *sql.DB) error { func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f2b39fe54..ba8eb671b 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -104,7 +103,7 @@ func createRoomsTable(db *sql.DB) error { func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 7ae987d6d..27d85e83b 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -23,7 +23,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -79,7 +79,7 @@ func createStateBlockTable(db *sql.DB) error { func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 15e14e2e0..4fc0fa48a 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -86,7 +85,7 @@ func createStateSnapshotTable(db *sql.DB) error { func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index cada0d8aa..af023b740 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -62,7 +62,7 @@ func createTransactionsTable(db *sql.DB) error { func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 29d54b83d..53b219294 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,7 +62,7 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index d430e5535..62fbce2d0 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -80,7 +79,7 @@ func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 694f4e217..22df3fb22 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -95,7 +94,7 @@ func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index e964770d7..a28d95fa5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -131,7 +130,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index e1aa1ebd3..c1d7347ae 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -80,7 +79,7 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 911a25168..2e58431d3 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -146,7 +145,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 3cb527678..7304bf0d5 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -22,7 +22,6 @@ import ( "strings" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,7 +80,7 @@ func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 6d9d91355..b07c0ac42 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -60,7 +59,7 @@ func preparePublishedTable(db *sql.DB) (tables.Published, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index b34981829..aed190b1e 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -69,7 +68,7 @@ func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 5215fa6f7..323945b88 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -74,7 +73,7 @@ func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 534a870cc..2dfb830d8 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -96,7 +95,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 5cb21e91c..a472437a3 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -75,7 +74,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 95cae99e5..ad623d65a 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -79,7 +78,7 @@ func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index e7471d7b0..1fb0a831d 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -59,7 +58,7 @@ func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.Prepare(db) diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 09eb26113..8ba890e75 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -61,16 +61,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 4eaa5b581..b57aa901f 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go index 0a2a26550..c1402d4d2 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" ) @@ -76,25 +77,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { - return - } - if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { - return - } - if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { - return - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return - } - if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { - return - } - if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) } func (s keyBackupStatements) countKeys( diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 51a462b32..d73447b49 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const keyBackupVersionTableSchema = ` @@ -72,25 +74,14 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { - return - } - if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { - return - } - if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { - return - } - if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { - return - } - if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { - return - } - if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) } func (s *keyBackupVersionStatements) insertKeyBackup( diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go index 86c197059..190d141b7 100644 --- a/userapi/storage/accounts/postgres/openid_table.go +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -39,14 +39,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index 45d802f18..9313864be 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -64,22 +64,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go index 7de96350c..9280fc87c 100644 --- a/userapi/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -63,20 +63,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID( diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 870a37065..871f996e0 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -62,16 +62,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 50f07237e..8a7c8fba7 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go index 675093514..837d38cf1 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" ) @@ -76,25 +77,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { - return - } - if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { - return - } - if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { - return - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return - } - if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { - return - } - if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) } func (s keyBackupStatements) countKeys( diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index a9e7bf5db..4211ed0f1 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const keyBackupVersionTableSchema = ` @@ -70,25 +72,14 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { - return - } - if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { - return - } - if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { - return - } - if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { - return - } - if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { - return - } - if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) } func (s *keyBackupVersionStatements) insertKeyBackup( diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go index 80b9dd4cb..98c0488b1 100644 --- a/userapi/storage/accounts/sqlite3/openid_table.go +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -41,14 +41,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return err } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index a67e892f7..a92e95663 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -66,22 +66,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 43112d389..9dc0e2d22 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -66,20 +66,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID(