diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go new file mode 100644 index 000000000..ce62a047a --- /dev/null +++ b/clientapi/routing/key_backup.go @@ -0,0 +1,291 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type keyBackupVersion struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` +} + +type keyBackupVersionCreateResponse struct { + Version string `json:"version"` +} + +type keyBackupVersionResponse struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int64 `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + +type keyBackupSessionRequest struct { + Rooms map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + } `json:"rooms"` +} + +type keyBackupSessionResponse struct { + Count int64 `json:"count"` + ETag string `json:"etag"` +} + +// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. +// Implements POST /_matrix/client/r0/room_keys/version +func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: "", + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. +// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} +func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + }, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionResponse{ + Algorithm: queryResp.Algorithm, + AuthData: queryResp.AuthData, + Count: queryResp.Count, + ETag: queryResp.ETag, + Version: queryResp.Version, + }, + } +} + +// Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion` +// Implements PUT /_matrix/client/r0/room_keys/version/{version} +func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. +// Implements DELETE /_matrix/client/r0/room_keys/version/{version} +func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + DeleteBackup: true, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Upload a bunch of session keys for a given `version`. +func UploadBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, +) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + Keys: *keys, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupSessionResponse{ + Count: performKeyBackupResp.KeyCount, + ETag: performKeyBackupResp.KeyETag, + }, + } +} + +// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set. +func GetBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string, +) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + ReturnKeys: true, + KeysForRoomID: roomID, + KeysForSessionID: sessionID, + }, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + if sessionID != "" { + // return the key itself if it was found + roomData, ok := queryResp.Keys[roomID] + if ok { + key, ok := roomData[sessionID] + if ok { + return util.JSONResponse{ + Code: 200, + JSON: key, + } + } + } + } else if roomID != "" { + roomData, ok := queryResp.Keys[roomID] + if ok { + // wrap response in "sessions" + return util.JSONResponse{ + Code: 200, + JSON: struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + }, + } + } + } else { + // response is the same as the upload request + var resp keyBackupSessionRequest + resp.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + for roomID, roomData := range queryResp.Keys { + resp.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + } + } + return util.JSONResponse{ + Code: 200, + JSON: resp, + } + } + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("keys not found"), + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9527261bf..272717560 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -896,6 +896,176 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + // Key Backup Versions (Metadata) + + getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return KeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return KeyBackupVersion(req, userAPI, device, "") + }) + + putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) + }) + + deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return CreateKeyBackupVersion(req, userAPI, device) + }) + + r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + + unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + unstableMux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + + // Inserting E2E Backup Keys + + // Bulk room and session + putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody keyBackupSessionRequest + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + + // Single room bulk session + putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + roomID := vars["roomID"] + var reqBody keyBackupSessionRequest + reqBody.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + reqBody.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: map[string]userapi.KeyBackupSession{}, + } + body := reqBody.Rooms[roomID] + resErr := clientutil.UnmarshalJSONRequest(req, &body) + if resErr != nil { + return *resErr + } + reqBody.Rooms[roomID] = body + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + + // Single room, single session + putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody userapi.KeyBackupSession + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + roomID := vars["roomID"] + sessionID := vars["sessionID"] + var keyReq keyBackupSessionRequest + keyReq.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + keyReq.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: make(map[string]userapi.KeyBackupSession), + } + keyReq.Rooms[roomID].Sessions[sessionID] = reqBody + return UploadBackupKeys(req, userAPI, device, version, &keyReq) + }) + + r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + + unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + + // Querying E2E Backup Keys + + getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") + }) + + getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") + }) + + getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) + }) + + r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + + unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + + // Deleting E2E Backup Keys + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/go.mod b/go.mod index 30e3c97bd..fbad10d6e 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/morikuni/aec v1.0.0 // indirect - github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 + github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 diff --git a/go.sum b/go.sum index a090dbc33..ff6090a56 100644 --- a/go.sum +++ b/go.sum @@ -1184,8 +1184,8 @@ github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 h1:txJOiDxsypF8RbzbcyOD3ovip+uUWNZE/Zo7qLdARZQ= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 h1:evlcQnJY+v8XRRchV3hXzpHDl6GcEZeLXAhlH9Csdww= diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 6cf17bea0..98cc396ad 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -152,3 +152,19 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide } return nil } + +// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type StatementList []struct { + Statement **sql.Stmt + SQL string +} + +// Prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s StatementList) Prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + return + } + } + return +} diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index e0976b12c..32e457821 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -20,7 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -67,7 +67,7 @@ func createEventJSONTable(db *sql.DB) error { func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 616823561..3a7cf03e3 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -85,7 +84,7 @@ func createEventStateKeysTable(db *sql.DB) error { func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index f4257850a..e558072a5 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -108,7 +107,7 @@ func createEventTypesTable(db *sql.DB) error { func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 88c82083c..c549fb650 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -24,7 +24,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -160,7 +159,7 @@ func createEventsTable(db *sql.DB) error { func prepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 0a2183e27..344302c8f 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -90,7 +89,7 @@ func createInvitesTable(db *sql.DB) error { func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b4a27900c..b0d906c80 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -168,7 +167,7 @@ func createMembershipTable(db *sql.DB) error { func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 4a93c3d65..bd4e853eb 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -73,7 +72,7 @@ func createPrevEventsTable(db *sql.DB) error { func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index c180576e3..8deb68441 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -58,7 +57,7 @@ func createPublishedTable(db *sql.DB) error { func preparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, diff --git a/roomserver/storage/postgres/redactions_table.go b/roomserver/storage/postgres/redactions_table.go index 3741d5f67..5614f2bd8 100644 --- a/roomserver/storage/postgres/redactions_table.go +++ b/roomserver/storage/postgres/redactions_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -68,7 +67,7 @@ func createRedactionsTable(db *sql.DB) error { func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index c808813ee..031825fee 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -70,7 +69,7 @@ func createRoomAliasesTable(db *sql.DB) error { func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f2b39fe54..ba8eb671b 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -104,7 +103,7 @@ func createRoomsTable(db *sql.DB) error { func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 7ae987d6d..27d85e83b 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -23,7 +23,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -79,7 +79,7 @@ func createStateBlockTable(db *sql.DB) error { func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 15e14e2e0..4fc0fa48a 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -22,7 +22,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -86,7 +85,7 @@ func createStateSnapshotTable(db *sql.DB) error { func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.Prepare(db) diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index cada0d8aa..af023b740 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -19,7 +19,7 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -62,7 +62,7 @@ func createTransactionsTable(db *sql.DB) error { func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{} - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 29d54b83d..53b219294 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,7 +62,7 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index d430e5535..62fbce2d0 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -80,7 +79,7 @@ func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 694f4e217..22df3fb22 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -95,7 +94,7 @@ func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index e964770d7..a28d95fa5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -131,7 +130,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index e1aa1ebd3..c1d7347ae 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -80,7 +79,7 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 911a25168..2e58431d3 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -146,7 +145,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 3cb527678..7304bf0d5 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -22,7 +22,6 @@ import ( "strings" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,7 +80,7 @@ func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 6d9d91355..b07c0ac42 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -60,7 +59,7 @@ func preparePublishedTable(db *sql.DB) (tables.Published, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index b34981829..aed190b1e 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -69,7 +68,7 @@ func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, {&s.selectRedactionInfoByRedactionEventIDStmt, selectRedactionInfoByRedactionEventIDSQL}, {&s.selectRedactionInfoByEventBeingRedactedStmt, selectRedactionInfoByEventBeingRedactedSQL}, diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 5215fa6f7..323945b88 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -74,7 +73,7 @@ func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 534a870cc..2dfb830d8 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -96,7 +95,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 5cb21e91c..a472437a3 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -75,7 +74,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 95cae99e5..ad623d65a 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -79,7 +78,7 @@ func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.Prepare(db) diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index e7471d7b0..1fb0a831d 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -59,7 +58,7 @@ func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { db: db, } - return s, shared.StatementList{ + return s, sqlutil.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.Prepare(db) diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 79aaebc0b..1bbb485cf 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -554,6 +554,10 @@ func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.Quer func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 96160c10d..2c195d128 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -373,6 +373,10 @@ func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *usera func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { diff --git a/sytest-whitelist b/sytest-whitelist index 4c25f1f1e..f56ef125e 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -543,3 +543,13 @@ Users with sufficient power-level can delete other's aliases Canonical alias can be set Canonical alias can include alt_aliases Can delete canonical alias +Can create backup version +Can update backup version +Responds correctly when backup is empty +Can backup keys +Can update keys with better versions +Will not update keys with worse versions +Will not back up to an old backup version +Can create more than 10 backup versions +Can delete backup +Deleted & recreated backups are empty diff --git a/userapi/api/api.go b/userapi/api/api.go index 407350123..75d06dd69 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -33,6 +33,8 @@ type UserInternalAPI interface { PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -42,6 +44,84 @@ type UserInternalAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } +type PerformKeyBackupRequest struct { + UserID string + Version string // optional if modifying a key backup + AuthData json.RawMessage + Algorithm string + DeleteBackup bool // if true will delete the backup based on 'Version'. + + // The keys to upload, if any. If blank, creates/updates/deletes key version metadata only. + Keys struct { + Rooms map[string]struct { + Sessions map[string]KeyBackupSession `json:"sessions"` + } `json:"rooms"` + } +} + +// KeyBackupData in https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0room_keyskeysroomidsessionid +type KeyBackupSession struct { + FirstMessageIndex int `json:"first_message_index"` + ForwardedCount int `json:"forwarded_count"` + IsVerified bool `json:"is_verified"` + SessionData json.RawMessage `json:"session_data"` +} + +func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if newKey.IsVerified && !a.IsVerified { + return true + } else if newKey.FirstMessageIndex < a.FirstMessageIndex { + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + return true + } else if newKey.ForwardedCount < a.ForwardedCount { + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + return true + } + return false +} + +// Internal KeyBackupData for passing to/from the storage layer +type InternalKeyBackupSession struct { + KeyBackupSession + RoomID string + SessionID string +} + +type PerformKeyBackupResponse struct { + Error string // set if there was a problem performing the request + BadInput bool // if set, the Error was due to bad input (HTTP 400) + + Exists bool // set to true if the Version exists + Version string // the newly created version + + KeyCount int64 // only set if Keys were given in the request + KeyETag string // only set if Keys were given in the request +} + +type QueryKeyBackupRequest struct { + UserID string + Version string // the version to query, if blank it means the latest + + ReturnKeys bool // whether to return keys in the backup response or just the metadata + KeysForRoomID string // optional string to return keys which belong to this room + KeysForSessionID string // optional string to return keys which belong to this (room, session) +} + +type QueryKeyBackupResponse struct { + Error string + Exists bool + + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int64 `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` + + Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true +} + // InputAccountDataRequest is the request for InputAccountData type InputAccountDataRequest struct { UserID string // required: the user to set account data for diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 21933c1c4..a2bc8ecf5 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -442,3 +442,114 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp return nil } + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // Delete metadata + if req.DeleteBackup { + if req.Version == "" { + res.BadInput = true + res.Error = "must specify a version to delete" + return + } + exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to delete backup: %s", err) + } + res.Exists = exists + res.Version = req.Version + return + } + // Create metadata + if req.Version == "" { + version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to create backup: %s", err) + } + res.Exists = err == nil + res.Version = version + return + } + // Update metadata + if len(req.Keys.Rooms) == 0 { + err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Exists = err == nil + res.Version = req.Version + return + } + // Upload Keys for a specific version metadata + a.uploadBackupKeys(ctx, req, res) +} + +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // you can only upload keys for the CURRENT version + version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + if err != nil { + res.Error = fmt.Sprintf("failed to query version: %s", err) + return + } + if deleted { + res.Error = "backup was deleted" + return + } + if version != req.Version { + res.BadInput = true + res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) + return + } + res.Exists = true + res.Version = version + + // map keys to a form we can upload more easily - the map ensures we have no duplicates. + var uploads []api.InternalKeyBackupSession + for roomID, data := range req.Keys.Rooms { + for sessionID, sessionData := range data.Sessions { + uploads = append(uploads, api.InternalKeyBackupSession{ + RoomID: roomID, + SessionID: sessionID, + KeyBackupSession: sessionData, + }) + } + } + count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + if err != nil { + res.Error = fmt.Sprintf("failed to upsert keys: %s", err) + return + } + res.KeyCount = count + res.KeyETag = etag +} + +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + res.Version = version + if err != nil { + if err == sql.ErrNoRows { + res.Exists = false + return + } + res.Error = fmt.Sprintf("failed to query key backup: %s", err) + return + } + res.Algorithm = algorithm + res.AuthData = authData + res.ETag = etag + res.Exists = !deleted + + if !req.ReturnKeys { + res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + if err != nil { + res.Error = fmt.Sprintf("failed to count keys: %s", err) + } + return + } + + result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + if err != nil { + res.Error = fmt.Sprintf("failed to query keys: %s", err) + return + } + res.Keys = result +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1cb5ef0a8..a89d1a266 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -36,7 +36,9 @@ const ( PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" + PerformKeyBackupPath = "/userapi/performKeyBackup" + QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" @@ -225,3 +227,24 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryOpenIDTokenPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + PerformKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} +func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + QueryKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 5aa61b909..887f7193b 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -54,6 +54,15 @@ type Database interface { DeactivateAccount(ctx context.Context, localpart string) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) + + // Key backups + CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) + DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) + UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) + GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) + CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 09eb26113..8ba890e75 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -61,16 +61,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 4eaa5b581..b57aa901f 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go new file mode 100644 index 000000000..c1402d4d2 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -0,0 +1,164 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt +} + +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go new file mode 100644 index 000000000..d73447b49 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -0,0 +1,161 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const keyBackupVersionTableSchema = ` +CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + etag TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + var versionInt int64 + if version == "" { + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v + } else { + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go index 86c197059..190d141b7 100644 --- a/userapi/storage/accounts/postgres/openid_table.go +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -39,14 +39,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index 45d802f18..9313864be 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -64,22 +64,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c5e74ed15..6bddbfc3f 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "time" @@ -45,6 +46,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -93,6 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -368,3 +377,145 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + return nil + }) + return +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go index 7de96350c..9280fc87c 100644 --- a/userapi/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -63,20 +63,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID( diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 870a37065..871f996e0 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -62,16 +62,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 50f07237e..8a7c8fba7 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go new file mode 100644 index 000000000..837d38cf1 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -0,0 +1,164 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt +} + +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go new file mode 100644 index 000000000..4211ed0f1 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -0,0 +1,159 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const keyBackupVersionTableSchema = ` +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version INTEGER PRIMARY KEY AUTOINCREMENT, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + etag TEXT NOT NULL, + deleted INTEGER DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + var versionInt int64 + if version == "" { + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v + } else { + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go index 80b9dd4cb..98c0488b1 100644 --- a/userapi/storage/accounts/sqlite3/openid_table.go +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -41,14 +41,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return err } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index a67e892f7..a92e95663 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -66,22 +66,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index c0f7118cb..d752e3db8 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "sync" "time" @@ -43,6 +44,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -97,6 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -156,8 +165,9 @@ func (d *Database) SetPassword( if err != nil { return err } - err = d.accounts.updatePassword(ctx, localpart, hash) - return err + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.updatePassword(ctx, localpart, hash) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -384,7 +394,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.deactivateAccount(ctx, localpart) + }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect @@ -406,3 +418,146 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + + return nil + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 43112d389..9dc0e2d22 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -66,20 +66,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID(