diff --git a/CHANGES.md b/CHANGES.md index 27356b3cb..8207d4844 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,25 @@ # Changelog +## Dendrite 0.4.1 (2021-07-26) + +### Features + +* Support for room version 7 has been added +* Key notary support is now more complete, allowing Dendrite to be used as a notary server for looking up signing keys +* State resolution v2 performance has been optimised further by caching the create event, power levels and join rules in memory instead of parsing them repeatedly +* The media API now handles cases where the maximum file size is configured to be less than 0 for unlimited size +* The `initial_state` in a `/createRoom` request is now respected when creating a room +* Code paths for checking if servers are joined to rooms have been optimised significantly + +### Fixes + +* A bug resulting in `cannot xref null state block with snapshot` during the new state storage migration has been fixed +* Invites are now retired correctly when rejecting an invite from a remote server which is no longer reachable +* The DNS cache `cache_lifetime` option is now handled correctly (contributed by [S7evinK](https://github.com/S7evinK)) +* Invalid events in a room join response are now dropped correctly, rather than failing the entire join +* The `prev_state` of an event will no longer be populated incorrectly to the state of the current event +* Receiving an invite to an unsupported room version will now correctly return the `M_UNSUPPORTED_ROOM_VERSION` error code instead of `M_BAD_JSON` (contributed by [meenal06](https://github.com/meenal06)) + ## Dendrite 0.4.0 (2021-07-12) ### Features diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go new file mode 100644 index 000000000..ce62a047a --- /dev/null +++ b/clientapi/routing/key_backup.go @@ -0,0 +1,291 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type keyBackupVersion struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` +} + +type keyBackupVersionCreateResponse struct { + Version string `json:"version"` +} + +type keyBackupVersionResponse struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int64 `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + +type keyBackupSessionRequest struct { + Rooms map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + } `json:"rooms"` +} + +type keyBackupSessionResponse struct { + Count int64 `json:"count"` + ETag string `json:"etag"` +} + +// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. +// Implements POST /_matrix/client/r0/room_keys/version +func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: "", + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. +// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} +func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + }, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionResponse{ + Algorithm: queryResp.Algorithm, + AuthData: queryResp.AuthData, + Count: queryResp.Count, + ETag: queryResp.ETag, + Version: queryResp.Version, + }, + } +} + +// Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion` +// Implements PUT /_matrix/client/r0/room_keys/version/{version} +func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. +// Implements DELETE /_matrix/client/r0/room_keys/version/{version} +func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + DeleteBackup: true, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Upload a bunch of session keys for a given `version`. +func UploadBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, +) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + Keys: *keys, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupSessionResponse{ + Count: performKeyBackupResp.KeyCount, + ETag: performKeyBackupResp.KeyETag, + }, + } +} + +// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set. +func GetBackupKeys( + req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string, +) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + UserID: device.UserID, + Version: version, + ReturnKeys: true, + KeysForRoomID: roomID, + KeysForSessionID: sessionID, + }, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + if sessionID != "" { + // return the key itself if it was found + roomData, ok := queryResp.Keys[roomID] + if ok { + key, ok := roomData[sessionID] + if ok { + return util.JSONResponse{ + Code: 200, + JSON: key, + } + } + } + } else if roomID != "" { + roomData, ok := queryResp.Keys[roomID] + if ok { + // wrap response in "sessions" + return util.JSONResponse{ + Code: 200, + JSON: struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + }, + } + } + } else { + // response is the same as the upload request + var resp keyBackupSessionRequest + resp.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + for roomID, roomData := range queryResp.Keys { + resp.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + } + } + return util.JSONResponse{ + Code: 200, + JSON: resp, + } + } + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("keys not found"), + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 8fb657174..7c27d4069 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -882,6 +882,176 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + // Key Backup Versions (Metadata) + + getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return KeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return KeyBackupVersion(req, userAPI, device, "") + }) + + putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) + }) + + deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) + }) + + postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return CreateKeyBackupVersion(req, userAPI, device) + }) + + r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + + unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + unstableMux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + + // Inserting E2E Backup Keys + + // Bulk room and session + putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody keyBackupSessionRequest + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + + // Single room bulk session + putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + roomID := vars["roomID"] + var reqBody keyBackupSessionRequest + reqBody.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + reqBody.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: map[string]userapi.KeyBackupSession{}, + } + body := reqBody.Rooms[roomID] + resErr := clientutil.UnmarshalJSONRequest(req, &body) + if resErr != nil { + return *resErr + } + reqBody.Rooms[roomID] = body + return UploadBackupKeys(req, userAPI, device, version, &reqBody) + }) + + // Single room, single session + putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + var reqBody userapi.KeyBackupSession + resErr := clientutil.UnmarshalJSONRequest(req, &reqBody) + if resErr != nil { + return *resErr + } + roomID := vars["roomID"] + sessionID := vars["sessionID"] + var keyReq keyBackupSessionRequest + keyReq.Rooms = make(map[string]struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }) + keyReq.Rooms[roomID] = struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: make(map[string]userapi.KeyBackupSession), + } + keyReq.Rooms[roomID].Sessions[sessionID] = reqBody + return UploadBackupKeys(req, userAPI, device, version, &keyReq) + }) + + r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + + unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + + // Querying E2E Backup Keys + + getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") + }) + + getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") + }) + + getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) + }) + + r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + + unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + unstableMux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + + // Deleting E2E Backup Keys + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 8795118ee..468659651 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -40,22 +40,29 @@ func InviteV2( ) util.JSONResponse { inviteReq := gomatrixserverlib.InviteV2Request{} err := json.Unmarshal(request.Content(), &inviteReq) - switch err.(type) { + switch e := err.(type) { + case gomatrixserverlib.UnsupportedRoomVersionError: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UnsupportedRoomVersion( + fmt.Sprintf("Room version %q is not supported by this server.", e.Version), + ), + } case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), } case nil: + return processInvite( + httpReq.Context(), true, inviteReq.Event(), inviteReq.RoomVersion(), inviteReq.InviteRoomState(), roomID, eventID, cfg, rsAPI, keys, + ) default: return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), } } - return processInvite( - httpReq.Context(), true, inviteReq.Event(), inviteReq.RoomVersion(), inviteReq.InviteRoomState(), roomID, eventID, cfg, rsAPI, keys, - ) } // InviteV1 implements /_matrix/federation/v1/invite/{roomID}/{eventID} diff --git a/go.mod b/go.mod index 30e3c97bd..fbad10d6e 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/morikuni/aec v1.0.0 // indirect - github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 + github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 diff --git a/go.sum b/go.sum index a090dbc33..ff6090a56 100644 --- a/go.sum +++ b/go.sum @@ -1184,8 +1184,8 @@ github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2 h1:txJOiDxsypF8RbzbcyOD3ovip+uUWNZE/Zo7qLdARZQ= -github.com/neilalexander/utp v0.1.1-0.20210720104546-52626cdf31b2/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= +github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 h1:evlcQnJY+v8XRRchV3hXzpHDl6GcEZeLXAhlH9Csdww= diff --git a/internal/version.go b/internal/version.go index 37f0c30d3..55997ffc4 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 4 - VersionPatch = 0 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go index 26d88e00f..940a920e1 100644 --- a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go @@ -174,6 +174,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { err = tx.QueryRow( `SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID, ).Scan(&createEventNID) + if err == sql.ErrNoRows { + continue + } if err != nil { return fmt.Errorf("cannot xref null state block with snapshot %d: %s", s.StateSnapshotNID, err) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 79aaebc0b..1bbb485cf 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -554,6 +554,10 @@ func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.Quer func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 96160c10d..2c195d128 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -373,6 +373,10 @@ func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *usera func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { diff --git a/sytest-whitelist b/sytest-whitelist index 720230e40..657d9c394 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -540,5 +540,15 @@ Key notary server must not overwrite a valid key with a spurious result from the GET /rooms/:room_id/aliases lists aliases Only room members can list aliases of a room Users with sufficient power-level can delete other's aliases +Can create backup version +Can update backup version +Responds correctly when backup is empty +Can backup keys +Can update keys with better versions +Will not update keys with worse versions +Will not back up to an old backup version +Can create more than 10 backup versions +Can delete backup +Deleted & recreated backups are empty GET /presence/:user_id/status fetches initial status PUT /presence/:user_id/status updates my presence \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index 1fb8b7b73..d16505b1e 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -35,6 +35,8 @@ type UserInternalAPI interface { PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -45,6 +47,84 @@ type UserInternalAPI interface { QueryPresenceForUser(ctx context.Context, req *QueryPresenceForUserRequest, res *QueryPresenceForUserResponse) error } +type PerformKeyBackupRequest struct { + UserID string + Version string // optional if modifying a key backup + AuthData json.RawMessage + Algorithm string + DeleteBackup bool // if true will delete the backup based on 'Version'. + + // The keys to upload, if any. If blank, creates/updates/deletes key version metadata only. + Keys struct { + Rooms map[string]struct { + Sessions map[string]KeyBackupSession `json:"sessions"` + } `json:"rooms"` + } +} + +// KeyBackupData in https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0room_keyskeysroomidsessionid +type KeyBackupSession struct { + FirstMessageIndex int `json:"first_message_index"` + ForwardedCount int `json:"forwarded_count"` + IsVerified bool `json:"is_verified"` + SessionData json.RawMessage `json:"session_data"` +} + +func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if newKey.IsVerified && !a.IsVerified { + return true + } else if newKey.FirstMessageIndex < a.FirstMessageIndex { + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + return true + } else if newKey.ForwardedCount < a.ForwardedCount { + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + return true + } + return false +} + +// Internal KeyBackupData for passing to/from the storage layer +type InternalKeyBackupSession struct { + KeyBackupSession + RoomID string + SessionID string +} + +type PerformKeyBackupResponse struct { + Error string // set if there was a problem performing the request + BadInput bool // if set, the Error was due to bad input (HTTP 400) + + Exists bool // set to true if the Version exists + Version string // the newly created version + + KeyCount int64 // only set if Keys were given in the request + KeyETag string // only set if Keys were given in the request +} + +type QueryKeyBackupRequest struct { + UserID string + Version string // the version to query, if blank it means the latest + + ReturnKeys bool // whether to return keys in the backup response or just the metadata + KeysForRoomID string // optional string to return keys which belong to this room + KeysForSessionID string // optional string to return keys which belong to this (room, session) +} + +type QueryKeyBackupResponse struct { + Error string + Exists bool + + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int64 `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` + + Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true +} + // InputAccountDataRequest is the request for InputAccountData type InputAccountDataRequest struct { UserID string // required: the user to set account data for diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 57abda410..dc57c22fd 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -483,3 +483,114 @@ func (a *UserInternalAPI) QueryPresenceForUser(ctx context.Context, req *api.Que } return nil } + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // Delete metadata + if req.DeleteBackup { + if req.Version == "" { + res.BadInput = true + res.Error = "must specify a version to delete" + return + } + exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to delete backup: %s", err) + } + res.Exists = exists + res.Version = req.Version + return + } + // Create metadata + if req.Version == "" { + version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to create backup: %s", err) + } + res.Exists = err == nil + res.Version = version + return + } + // Update metadata + if len(req.Keys.Rooms) == 0 { + err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Exists = err == nil + res.Version = req.Version + return + } + // Upload Keys for a specific version metadata + a.uploadBackupKeys(ctx, req, res) +} + +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // you can only upload keys for the CURRENT version + version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + if err != nil { + res.Error = fmt.Sprintf("failed to query version: %s", err) + return + } + if deleted { + res.Error = "backup was deleted" + return + } + if version != req.Version { + res.BadInput = true + res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) + return + } + res.Exists = true + res.Version = version + + // map keys to a form we can upload more easily - the map ensures we have no duplicates. + var uploads []api.InternalKeyBackupSession + for roomID, data := range req.Keys.Rooms { + for sessionID, sessionData := range data.Sessions { + uploads = append(uploads, api.InternalKeyBackupSession{ + RoomID: roomID, + SessionID: sessionID, + KeyBackupSession: sessionData, + }) + } + } + count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + if err != nil { + res.Error = fmt.Sprintf("failed to upsert keys: %s", err) + return + } + res.KeyCount = count + res.KeyETag = etag +} + +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + res.Version = version + if err != nil { + if err == sql.ErrNoRows { + res.Exists = false + return + } + res.Error = fmt.Sprintf("failed to query key backup: %s", err) + return + } + res.Algorithm = algorithm + res.AuthData = authData + res.ETag = etag + res.Exists = !deleted + + if !req.ReturnKeys { + res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + if err != nil { + res.Error = fmt.Sprintf("failed to count keys: %s", err) + } + return + } + + result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + if err != nil { + res.Error = fmt.Sprintf("failed to query keys: %s", err) + return + } + res.Keys = result +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 8d4acbbc3..9423160b1 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,6 +37,7 @@ const ( PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" + PerformKeyBackupPath = "/userapi/performKeyBackup" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -46,6 +47,7 @@ const ( QuerySearchProfilesPath = "/userapi/querySearchProfiles" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" QueryPresenceForUserPath = "/userapi/queryPresenceForUser" + QueryKeyBackupPath = "/userapi/queryKeyBackup" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -243,3 +245,25 @@ func (h *httpUserInternalAPI) QueryPresenceForUser(ctx context.Context, req *api apiURL := h.apiURL + QueryPresenceForUserPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + PerformKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} + +func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + QueryKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 5aa61b909..887f7193b 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -54,6 +54,15 @@ type Database interface { DeactivateAccount(ctx context.Context, localpart string) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) + + // Key backups + CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) + DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) + UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) + GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) + CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go new file mode 100644 index 000000000..0a2a26550 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -0,0 +1,174 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt +} + +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { + return + } + if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { + return + } + if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { + return + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return + } + if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { + return + } + if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { + return + } + return +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go new file mode 100644 index 000000000..51a462b32 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -0,0 +1,170 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + etag TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + var versionInt int64 + if version == "" { + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v + } else { + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c5e74ed15..6bddbfc3f 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "time" @@ -45,6 +46,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -93,6 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -368,3 +377,145 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + return nil + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go new file mode 100644 index 000000000..675093514 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -0,0 +1,174 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" +) + +const keyBackupTableSchema = ` +CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + + version TEXT NOT NULL, + first_message_index INTEGER NOT NULL, + forwarded_count INTEGER NOT NULL, + is_verified BOOLEAN NOT NULL, + session_data TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +` + +const insertBackupKeySQL = "" + + "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +const updateBackupKeySQL = "" + + "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +const countKeysSQL = "" + + "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + +const selectKeysSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2" + +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + +type keyBackupStatements struct { + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt +} + +func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupTableSchema) + if err != nil { + return + } + if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { + return + } + if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { + return + } + if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { + return + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return + } + if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { + return + } + if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { + return + } + return +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (count int64, err error) { + err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + ) + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + ) + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, txn *sql.Tx, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") + for rows.Next() { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + return nil, err + } + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go new file mode 100644 index 000000000..a9e7bf5db --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -0,0 +1,168 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version INTEGER PRIMARY KEY AUTOINCREMENT, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + etag TEXT NOT NULL, + deleted INTEGER DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const updateKeyBackupETagSQL = "" + + "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt + updateKeyBackupETagStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, txn *sql.Tx, userID, version, etag string, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + var versionInt int64 + if version == "" { + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v + } else { + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index c0f7118cb..d752e3db8 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "sync" "time" @@ -43,6 +44,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -97,6 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -156,8 +165,9 @@ func (d *Database) SetPassword( if err != nil { return err } - err = d.accounts.updatePassword(ctx, localpart, hash) - return err + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.updatePassword(ctx, localpart, hash) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -384,7 +394,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) + return d.writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.accounts.deactivateAccount(ctx, localpart) + }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect @@ -406,3 +418,146 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + + return nil + }) + return +}