Key backups (1/2) : Add E2E session backup metadata tables (#1943)

* Initial key backup paths and userapi API

* Fix unit tests

* Add key backup table

* Glue REST API to database

* Linting

* use writer on sqlite
This commit is contained in:
kegsay 2021-07-27 12:47:32 +01:00 committed by GitHub
parent e3679799ea
commit 32538640db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 712 additions and 0 deletions

View file

@ -0,0 +1,173 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"fmt"
"net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
type keyBackupVersion struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
}
type keyBackupVersionCreateResponse struct {
Version string `json:"version"`
}
type keyBackupVersionResponse struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Count int `json:"count"`
ETag string `json:"etag"`
Version string `json:"version"`
}
// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`.
// Implements POST /_matrix/client/r0/room_keys/version
func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse {
var kb keyBackupVersion
resErr := httputil.UnmarshalJSONRequest(req, &kb)
if resErr != nil {
return *resErr
}
var performKeyBackupResp userapi.PerformKeyBackupResponse
userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID,
Version: "",
AuthData: kb.AuthData,
Algorithm: kb.Algorithm,
}, &performKeyBackupResp)
if performKeyBackupResp.Error != "" {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
}
return util.JSONResponse{
Code: 200,
JSON: keyBackupVersionCreateResponse{
Version: performKeyBackupResp.Version,
},
}
}
// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned.
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp)
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
if !queryResp.Exists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("version not found"),
}
}
return util.JSONResponse{
Code: 200,
JSON: keyBackupVersionResponse{
Algorithm: queryResp.Algorithm,
AuthData: queryResp.AuthData,
Count: queryResp.Count,
ETag: queryResp.ETag,
Version: queryResp.Version,
},
}
}
// Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion`
// Implements PUT /_matrix/client/r0/room_keys/version/{version}
func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
var kb keyBackupVersion
resErr := httputil.UnmarshalJSONRequest(req, &kb)
if resErr != nil {
return *resErr
}
var performKeyBackupResp userapi.PerformKeyBackupResponse
userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID,
Version: version,
AuthData: kb.AuthData,
Algorithm: kb.Algorithm,
}, &performKeyBackupResp)
if performKeyBackupResp.Error != "" {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
}
if !performKeyBackupResp.Exists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("backup version not found"),
}
}
// Unclear what the 200 body should be
return util.JSONResponse{
Code: 200,
JSON: keyBackupVersionCreateResponse{
Version: performKeyBackupResp.Version,
},
}
}
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
var performKeyBackupResp userapi.PerformKeyBackupResponse
userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
UserID: device.UserID,
Version: version,
DeleteBackup: true,
}, &performKeyBackupResp)
if performKeyBackupResp.Error != "" {
if performKeyBackupResp.BadInput {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
}
}
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
}
if !performKeyBackupResp.Exists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("backup version not found"),
}
}
// Unclear what the 200 body should be
return util.JSONResponse{
Code: 200,
JSON: keyBackupVersionCreateResponse{
Version: performKeyBackupResp.Version,
},
}
}

View file

@ -896,6 +896,48 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// Key Backup Versions
r0mux.Handle("/room_keys/version/{versionID}",
httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version")
return KeyBackupVersion(req, userAPI, device, version)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version",
httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return KeyBackupVersion(req, userAPI, device, "")
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{versionID}",
httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version")
if version == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
}
return ModifyKeyBackupVersionAuthData(req, userAPI, device, version)
}),
).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{versionID}",
httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version")
if version == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
}
return DeleteKeyBackupVersion(req, userAPI, device, version)
}),
).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version",
httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateKeyBackupVersion(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated. // Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}", r0mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View file

@ -554,6 +554,10 @@ func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.Quer
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
return nil return nil
} }
func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) {
}
func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) {
}
type testRoomserverAPI struct { type testRoomserverAPI struct {
// use a trace API as it implements method stubs so we don't need to have them here. // use a trace API as it implements method stubs so we don't need to have them here.

View file

@ -373,6 +373,10 @@ func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *usera
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
return nil return nil
} }
func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) {
}
func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) {
}
func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
dev, ok := u.accessTokens[req.AccessToken] dev, ok := u.accessTokens[req.AccessToken]
if !ok { if !ok {

View file

@ -540,3 +540,4 @@ Key notary server must not overwrite a valid key with a spurious result from the
GET /rooms/:room_id/aliases lists aliases GET /rooms/:room_id/aliases lists aliases
Only room members can list aliases of a room Only room members can list aliases of a room
Users with sufficient power-level can delete other's aliases Users with sufficient power-level can delete other's aliases
Can create more than 10 backup versions

View file

@ -33,6 +33,8 @@ type UserInternalAPI interface {
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse)
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
@ -42,6 +44,37 @@ type UserInternalAPI interface {
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
} }
type PerformKeyBackupRequest struct {
UserID string
Version string // optional if modifying a key backup
AuthData json.RawMessage
Algorithm string
DeleteBackup bool // if true will delete the backup based on 'Version'.
}
type PerformKeyBackupResponse struct {
Error string // set if there was a problem performing the request
BadInput bool // if set, the Error was due to bad input (HTTP 400)
Exists bool // set to true if the Version exists
Version string
}
type QueryKeyBackupRequest struct {
UserID string
Version string // the version to query, if blank it means the latest
}
type QueryKeyBackupResponse struct {
Error string
Exists bool
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Count int `json:"count"`
ETag string `json:"etag"`
Version string `json:"version"`
}
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
type InputAccountDataRequest struct { type InputAccountDataRequest struct {
UserID string // required: the user to set account data for UserID string // required: the user to set account data for

View file

@ -442,3 +442,57 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
return nil return nil
} }
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// Delete
if req.DeleteBackup {
if req.Version == "" {
res.BadInput = true
res.Error = "must specify a version to delete"
return
}
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
res.Exists = exists
res.Version = req.Version
return
}
// Create
if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
res.Exists = err == nil
res.Version = version
return
}
// Update
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
res.Version = req.Version
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
res.Exists = false
return
}
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
return
}
res.Algorithm = algorithm
res.AuthData = authData
res.Exists = !deleted
// TODO:
res.Count = 0
res.ETag = ""
}

View file

@ -36,7 +36,9 @@ const (
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup"
QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices" QueryDevicesPath = "/userapi/queryDevices"
@ -225,3 +227,24 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que
apiURL := h.apiURL + QueryOpenIDTokenPath apiURL := h.apiURL + QueryOpenIDTokenPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup")
defer span.Finish()
apiURL := h.apiURL + PerformKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
}
func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
defer span.Finish()
apiURL := h.apiURL + QueryKeyBackupPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = err.Error()
}
}

View file

@ -54,6 +54,12 @@ type Database interface {
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
// Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,144 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
)
const keyBackupVersionTableSchema = `
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well?
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt
updateKeyBackupAuthDataStmt *sql.Stmt
deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
}
if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil {
return
}
if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil {
return
}
if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil {
return
}
if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil {
return
}
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return
}
return
}
func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return false, fmt.Errorf("invalid version")
}
result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
if err != nil {
return false, err
}
ra, err := result.RowsAffected()
if err != nil {
return false, err
}
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
var versionInt int64
if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
} else {
versionInt, err = strconv.ParseInt(version, 10, 64)
}
if err != nil {
return
}
versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int
var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt)
deleted = deletedInt == 1
authData = json.RawMessage(authDataStr)
return
}

View file

@ -45,6 +45,7 @@ type Database struct {
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackups keyBackupVersionStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS int64
@ -93,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -368,3 +372,42 @@ func (d *Database) GetOpenIDTokenAttributes(
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
} }
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData)
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}

View file

@ -0,0 +1,142 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
)
const keyBackupVersionTableSchema = `
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version INTEGER PRIMARY KEY AUTOINCREMENT,
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
deleted INTEGER DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well?
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt
updateKeyBackupAuthDataStmt *sql.Stmt
deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
}
if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil {
return
}
if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil {
return
}
if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil {
return
}
if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil {
return
}
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return
}
return
}
func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return false, fmt.Errorf("invalid version")
}
result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
if err != nil {
return false, err
}
ra, err := result.RowsAffected()
if err != nil {
return false, err
}
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
var versionInt int64
if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
} else {
versionInt, err = strconv.ParseInt(version, 10, 64)
}
if err != nil {
return
}
versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int
var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt)
deleted = deletedInt == 1
authData = json.RawMessage(authDataStr)
return
}

View file

@ -43,6 +43,7 @@ type Database struct {
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
openIDTokens tokenStatements openIDTokens tokenStatements
keyBackups keyBackupVersionStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
bcryptCost int bcryptCost int
openIDTokenLifetimeMS int64 openIDTokenLifetimeMS int64
@ -97,6 +98,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil return d, nil
} }
@ -406,3 +410,42 @@ func (d *Database) GetOpenIDTokenAttributes(
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
} }
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData)
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}