Key Backups (2/3) : Add E2E backup key tables (#1945)

* Add PUT key backup endpoints and glue them to PerformKeyBackup

* Add tables for storing backup keys and glue them into the user API

* Don't create tables whilst still WIPing

* writer on sqlite please

* Linting
This commit is contained in:
kegsay 2021-07-27 17:08:53 +01:00 committed by GitHub
parent a060df91e2
commit b3754d68fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 737 additions and 45 deletions

View file

@ -42,6 +42,17 @@ type keyBackupVersionResponse struct {
Version string `json:"version"` Version string `json:"version"`
} }
type keyBackupSessionRequest struct {
Rooms map[string]struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
} `json:"rooms"`
}
type keyBackupSessionResponse struct {
Count int64 `json:"count"`
ETag string `json:"etag"`
}
// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. // Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`.
// Implements POST /_matrix/client/r0/room_keys/version // Implements POST /_matrix/client/r0/room_keys/version
func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse {
@ -171,3 +182,37 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI,
}, },
} }
} }
// Upload a bunch of session keys for a given `version`.
func UploadBackupKeys(
req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest,
) util.JSONResponse {
var performKeyBackupResp userapi.PerformKeyBackupResponse
userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID,
Version: version,
Keys: *keys,
}, &performKeyBackupResp)
if performKeyBackupResp.Error != "" {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
}
if !performKeyBackupResp.Exists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("backup version not found"),
}
}
return util.JSONResponse{
Code: 200,
JSON: keyBackupSessionResponse{
Count: performKeyBackupResp.KeyCount,
ETag: performKeyBackupResp.KeyETag,
},
}
}

View file

@ -938,6 +938,85 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// E2E Backup Keys
// Bulk room and session
r0mux.Handle("/room_keys/keys",
httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version")
if version == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
}
var reqBody keyBackupSessionRequest
resErr := clientutil.UnmarshalJSONRequest(req, &reqBody)
if resErr != nil {
return *resErr
}
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
}),
).Methods(http.MethodPut)
// Single room bulk session
r0mux.Handle("/room_keys/keys/{roomID}",
httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
version := req.URL.Query().Get("version")
if version == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
}
roomID := vars["roomID"]
var reqBody keyBackupSessionRequest
reqBody.Rooms[roomID] = struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{
Sessions: map[string]userapi.KeyBackupSession{},
}
body := reqBody.Rooms[roomID]
resErr := clientutil.UnmarshalJSONRequest(req, &body)
if resErr != nil {
return *resErr
}
reqBody.Rooms[roomID] = body
return UploadBackupKeys(req, userAPI, device, version, &reqBody)
}),
).Methods(http.MethodPut)
// Single room, single session
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
version := req.URL.Query().Get("version")
if version == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
}
var reqBody userapi.KeyBackupSession
resErr := clientutil.UnmarshalJSONRequest(req, &reqBody)
if resErr != nil {
return *resErr
}
roomID := vars["roomID"]
sessionID := vars["sessionID"]
var keyReq keyBackupSessionRequest
keyReq.Rooms[roomID] = struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{}
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
}),
).Methods(http.MethodPut)
// Supplying a device ID is deprecated. // Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}", r0mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -540,4 +540,3 @@ Key notary server must not overwrite a valid key with a spurious result from the
GET /rooms/:room_id/aliases lists aliases GET /rooms/:room_id/aliases lists aliases
Only room members can list aliases of a room Only room members can list aliases of a room
Users with sufficient power-level can delete other's aliases Users with sufficient power-level can delete other's aliases
Can create more than 10 backup versions

View file

@ -50,13 +50,39 @@ type PerformKeyBackupRequest struct {
AuthData json.RawMessage AuthData json.RawMessage
Algorithm string Algorithm string
DeleteBackup bool // if true will delete the backup based on 'Version'. DeleteBackup bool // if true will delete the backup based on 'Version'.
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
Keys struct {
Rooms map[string]struct {
Sessions map[string]KeyBackupSession `json:"sessions"`
} `json:"rooms"`
}
}
// KeyBackupData in https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0room_keyskeysroomidsessionid
type KeyBackupSession struct {
FirstMessageIndex int `json:"first_message_index"`
ForwardedCount int `json:"forwarded_count"`
IsVerified bool `json:"is_verified"`
SessionData json.RawMessage `json:"session_data"`
}
// Internal KeyBackupData for passing to/from the storage layer
type InternalKeyBackupSession struct {
KeyBackupSession
RoomID string
SessionID string
} }
type PerformKeyBackupResponse struct { type PerformKeyBackupResponse struct {
Error string // set if there was a problem performing the request Error string // set if there was a problem performing the request
BadInput bool // if set, the Error was due to bad input (HTTP 400) BadInput bool // if set, the Error was due to bad input (HTTP 400)
Exists bool // set to true if the Version exists
Version string Exists bool // set to true if the Version exists
Version string // the newly created version
KeyCount int64 // only set if Keys were given in the request
KeyETag string // only set if Keys were given in the request
} }
type QueryKeyBackupRequest struct { type QueryKeyBackupRequest struct {

View file

@ -444,7 +444,7 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
} }
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// Delete // Delete metadata
if req.DeleteBackup { if req.DeleteBackup {
if req.Version == "" { if req.Version == "" {
res.BadInput = true res.BadInput = true
@ -459,7 +459,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
res.Version = req.Version res.Version = req.Version
return return
} }
// Create // Create metadata
if req.Version == "" { if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil { if err != nil {
@ -469,16 +469,55 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
res.Version = version res.Version = version
return return
} }
// Update // Update metadata
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if len(req.Keys.Rooms) == 0 {
if err != nil { err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
res.Error = fmt.Sprintf("failed to update backup: %s", err) if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
res.Version = req.Version
return
} }
res.Version = req.Version // Upload Keys for a specific version metadata
a.uploadBackupKeys(ctx, req, res)
}
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// ensure the version metadata exists
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
}
if deleted {
res.Error = "backup was deleted"
return
}
res.Exists = true
res.Version = version
// map keys to a form we can upload more easily - the map ensures we have no duplicates.
var uploads []api.InternalKeyBackupSession
for roomID, data := range req.Keys.Rooms {
for sessionID, sessionData := range data.Sessions {
uploads = append(uploads, api.InternalKeyBackupSession{
RoomID: roomID,
SessionID: sessionID,
KeyBackupSession: sessionData,
})
}
}
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
}
res.KeyCount = count
res.KeyETag = etag
} }
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version res.Version = version
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -494,5 +533,5 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
// TODO: // TODO:
res.Count = 0 res.Count = 0
res.ETag = "" res.ETag = etag
} }

View file

@ -59,7 +59,8 @@ type Database interface {
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,134 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/api"
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
version TEXT NOT NULL,
first_message_index INTEGER NOT NULL,
forwarded_count INTEGER NOT NULL,
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
return
}
if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil {
return
}
if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil {
return
}
if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil {
return
}
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
return
}
func (s keyBackupStatements) countKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return
}
func (s *keyBackupStatements) insertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData),
)
return
}
func (s *keyBackupStatements) updateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version,
)
return
}
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession
// room_id, session_id, first_message_index, forwarded_count, is_verified, session_data
var sessionDataStr string
if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil {
return nil, err
}
key.SessionData = json.RawMessage(sessionDataStr)
roomData := result[key.RoomID]
if roomData == nil {
roomData = make(map[string]api.KeyBackupSession)
}
roomData[key.SessionID] = key.KeyBackupSession
result[key.RoomID] = roomData
}
return result, nil
}

View file

@ -33,6 +33,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
algorithm TEXT NOT NULL, algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL, auth_data TEXT NOT NULL,
etag TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL deleted SMALLINT DEFAULT 0 NOT NULL
); );
@ -40,16 +41,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_
` `
const insertKeyBackupSQL = "" + const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" + const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" + const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" + const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
@ -60,8 +64,10 @@ type keyBackupVersionStatements struct {
deleteKeyBackupStmt *sql.Stmt deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt selectLatestVersionStmt *sql.Stmt
updateKeyBackupETagStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema) _, err = db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
@ -82,14 +88,17 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return return
} }
if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil {
return
}
return return
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
@ -104,6 +113,17 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
@ -124,7 +144,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64
if version == "" { if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
@ -137,7 +157,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
versionResult = strconv.FormatInt(versionInt, 10) versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int var deletedInt int
var authDataStr string var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt)
deleted = deletedInt == 1 deleted = deletedInt == 1
authData = json.RawMessage(authDataStr) authData = json.RawMessage(authDataStr)
return return

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"time" "time"
@ -45,7 +46,8 @@ type Database struct {
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackups keyBackupVersionStatements keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS int64
@ -94,9 +96,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.keyBackups.prepare(db); err != nil { /*
return nil, err if err = d.keyBackupVersions.prepare(db); err != nil {
} return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
} */
return d, nil return d, nil
} }
@ -377,7 +383,7 @@ func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage, ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) { ) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err return err
}) })
return return
@ -387,7 +393,7 @@ func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage, ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) { ) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
}) })
return return
} }
@ -396,7 +402,7 @@ func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (exists bool, err error) { ) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
@ -404,10 +410,101 @@ func (d *Database) DeleteKeyBackup(
func (d *Database) GetKeyBackup( func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
} }
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return err
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return err
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}

View file

@ -0,0 +1,134 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/api"
)
const keyBackupTableSchema = `
CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
version TEXT NOT NULL,
first_message_index INTEGER NOT NULL,
forwarded_count INTEGER NOT NULL,
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
`
const insertBackupKeySQL = "" +
"INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const updateBackupKeySQL = "" +
"UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
"WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
const countKeysSQL = "" +
"SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
return
}
if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil {
return
}
if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil {
return
}
if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil {
return
}
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
return
}
func (s keyBackupStatements) countKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return
}
func (s *keyBackupStatements) insertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData),
)
return
}
func (s *keyBackupStatements) updateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version,
)
return
}
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession
// room_id, session_id, first_message_index, forwarded_count, is_verified, session_data
var sessionDataStr string
if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil {
return nil, err
}
key.SessionData = json.RawMessage(sessionDataStr)
roomData := result[key.RoomID]
if roomData == nil {
roomData = make(map[string]api.KeyBackupSession)
}
roomData[key.SessionID] = key.KeyBackupSession
result[key.RoomID] = roomData
}
return result, nil
}

View file

@ -31,6 +31,7 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
version INTEGER PRIMARY KEY AUTOINCREMENT, version INTEGER PRIMARY KEY AUTOINCREMENT,
algorithm TEXT NOT NULL, algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL, auth_data TEXT NOT NULL,
etag TEXT NOT NULL,
deleted INTEGER DEFAULT 0 NOT NULL deleted INTEGER DEFAULT 0 NOT NULL
); );
@ -38,16 +39,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_
` `
const insertKeyBackupSQL = "" + const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? const updateKeyBackupAuthDataSQL = "" +
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const updateKeyBackupETagSQL = "" +
"UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" + const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" + const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" + const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
@ -58,8 +62,10 @@ type keyBackupVersionStatements struct {
deleteKeyBackupStmt *sql.Stmt deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt selectLatestVersionStmt *sql.Stmt
updateKeyBackupETagStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema) _, err = db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
@ -80,14 +86,17 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return return
} }
if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil {
return
}
return return
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
@ -102,6 +111,17 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
@ -122,7 +142,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64
if version == "" { if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
@ -135,7 +155,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
versionResult = strconv.FormatInt(versionInt, 10) versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int var deletedInt int
var authDataStr string var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt)
deleted = deletedInt == 1 deleted = deletedInt == 1
authData = json.RawMessage(authDataStr) authData = json.RawMessage(authDataStr)
return return

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -43,7 +44,8 @@ type Database struct {
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackups keyBackupVersionStatements keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS int64
@ -98,9 +100,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.keyBackups.prepare(db); err != nil { /*
return nil, err if err = d.keyBackupVersions.prepare(db); err != nil {
} return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
} */
return d, nil return d, nil
} }
@ -418,7 +424,7 @@ func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage, ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) { ) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err return err
}) })
return return
@ -428,7 +434,7 @@ func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage, ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) { ) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
}) })
return return
} }
@ -437,7 +443,7 @@ func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (exists bool, err error) { ) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
@ -445,10 +451,102 @@ func (d *Database) DeleteKeyBackup(
func (d *Database) GetKeyBackup( func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string, ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err return err
}) })
return return
} }
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return err
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return err
}
}
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}