dendrite/userapi/storage/accounts/cosmosdb/key_backup_table.go
alexfca fd7f25479b
Upgrade Dendrite 0.5.0 support for CosmosDB (#15)
* - Add CosmosDB back
- Add missing methods to blacklist_table.go
- Add missing methods to device_keys_table.go
- Add missing methods to events_table.go
- Add missing methods to membership_table.go
- Update state_block_table.go (due to the SQL refactor)
- Update state_snapshot_table.go (due to the SQL refactor)
- Add new key_backup_table.go
- Add new key_backup_version_table.go
- Code compiles but has runtime errors

* Message sending + receiving working
Rooms and DMs working
- Add CrossSigningKeys table
- Add CrossSigningSigs table
- Refactor DeviceKeys table
- Fix OneTimeKeys
- Update the KeyServer storage.go to use a PartitionStorer instead of a specific SQL PartitionOffsetStatements
- Fix small issues from the previous commit
- Implement DeleteSendToDeviceMessages

Co-authored-by: alexf@example.com <alexf@example.com>
2021-09-10 16:04:17 +10:00

415 lines
15 KiB
Go

// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cosmosdb
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// const keyBackupTableSchema = `
// CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
// user_id TEXT NOT NULL,
// room_id TEXT NOT NULL,
// session_id TEXT NOT NULL,
// version TEXT NOT NULL,
// first_message_index INTEGER NOT NULL,
// forwarded_count INTEGER NOT NULL,
// is_verified BOOLEAN NOT NULL,
// session_data TEXT NOT NULL
// );
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
// CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
// `
// KeyBackupCosmosData is the Cosmos DB document envelope for a single backed
// up room key. The envelope fields carry the storage metadata; the row itself
// lives under the "mx_userapi_account_e2e_room_keys" property.
type KeyBackupCosmosData struct {
	// Id is the document identifier.
	Id string `json:"id"`
	// Pk is the partition key the document is stored under.
	Pk string `json:"_pk"`
	// Tn is filled with the configured tenant name when a document is created.
	Tn string `json:"_sid"`
	// Cn is the collection name; every query in this file filters on it.
	Cn string `json:"_cn"`
	// ETag is handed to the replace options when the document is updated.
	ETag string `json:"_etag"`
	// Timestamp is set to the Unix time (seconds) at document creation.
	Timestamp int64 `json:"_ts"`
	// KeyBackup is the actual key backup row.
	KeyBackup KeyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"`
}
// KeyBackupCosmos is the payload of a key backup document. Its fields mirror
// the columns of the original account_e2e_room_keys SQL table.
type KeyBackupCosmos struct {
	UserId    string `json:"user_id"`
	RoomId    string `json:"room_id"`
	SessionId string `json:"session_id"`
	// Version must serialise as "version": the count/select queries in this
	// file all filter on c.mx_userapi_account_e2e_room_keys.version. The tag
	// was previously misspelled "vesion", so stored documents could never
	// match those filters. Documents written with the misspelled key will not
	// populate this field and need to be migrated.
	Version           string `json:"version"`
	FirstMessageIndex int    `json:"first_message_index"`
	ForwardedCount    int    `json:"forwarded_count"`
	IsVerified        bool   `json:"is_verified"`
	// SessionData holds the raw session payload; encoding/json stores a
	// []byte as a base64 string, which round-trips losslessly.
	SessionData []byte `json:"session_data"`
}
// KeyBackupCosmosNumber receives the single aggregate value produced by a
// query that projects its result "as number" (see countKeysSQL).
type KeyBackupCosmosNumber struct {
	// Number is the aggregate value returned by the query.
	Number int64 `json:"number"`
}
// const insertBackupKeySQL = "" +
// "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
// "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
// const updateBackupKeySQL = "" +
// "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
// "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
// countKeysSQL counts the key backup documents of one user and backup version.
// Parameters: @x1 collection name, @x2 user ID, @x3 backup version.
const countKeysSQL = "select count(c._ts) as number from c" +
	" where c._cn = @x1" +
	" and c.mx_userapi_account_e2e_room_keys.user_id = @x2" +
	" and c.mx_userapi_account_e2e_room_keys.version = @x3 "
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2"
// selectKeysSQL fetches every key backup document of one user and backup
// version. Parameters: @x1 collection name, @x2 user ID, @x3 backup version.
const selectKeysSQL = "select * from c" +
	" where c._cn = @x1" +
	" and c.mx_userapi_account_e2e_room_keys.user_id = @x2" +
	" and c.mx_userapi_account_e2e_room_keys.version = @x3 "
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3"
// selectKeysByRoomIDSQL fetches the key backup documents of one user and
// backup version restricted to a single room.
// Parameters: @x1 collection name, @x2 user ID, @x3 backup version, @x4 room ID.
const selectKeysByRoomIDSQL = "select * from c" +
	" where c._cn = @x1" +
	" and c.mx_userapi_account_e2e_room_keys.user_id = @x2" +
	" and c.mx_userapi_account_e2e_room_keys.version = @x3" +
	" and c.mx_userapi_account_e2e_room_keys.room_id = @x4 "
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
// selectKeysByRoomIDAndSessionIDSQL fetches the key backup documents of one
// user and backup version restricted to a single room and session.
// Parameters: @x1 collection name, @x2 user ID, @x3 backup version,
// @x4 room ID, @x5 session ID.
const selectKeysByRoomIDAndSessionIDSQL = "select * from c" +
	" where c._cn = @x1" +
	" and c.mx_userapi_account_e2e_room_keys.user_id = @x2" +
	" and c.mx_userapi_account_e2e_room_keys.version = @x3" +
	" and c.mx_userapi_account_e2e_room_keys.room_id = @x4" +
	" and c.mx_userapi_account_e2e_room_keys.session_id = @x5 "
// keyBackupStatements bundles the query texts and the database handle used to
// read and write key backup documents.
type keyBackupStatements struct {
	// db gives access to the Cosmos DB connection and configuration.
	db *Database

	// Inserts and updates are done with direct document operations, so the
	// prepared statements of the SQL implementation have no counterpart here.
	// insertBackupKeyStmt *sql.Stmt
	// updateBackupKeyStmt *sql.Stmt

	// Query texts, assigned in prepare.
	countKeysStmt                      string
	selectKeysStmt                     string
	selectKeysByRoomIDStmt             string
	selectKeysByRoomIDAndSessionIDStmt string

	// tableName is the logical table name from which the collection name and
	// the partition key are derived.
	tableName string
	// serverName is the server name handed to prepare.
	serverName gomatrixserverlib.ServerName
}
// queryKeyBackup executes qry with the supplied parameters inside the key
// backup partition and returns the matching documents.
//
// On a client error the error is returned as-is together with a nil slice.
func queryKeyBackup(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosData, error) {
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, collection)

	var docs []KeyBackupCosmosData
	queryOpts := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	cosmosQuery := cosmosdbapi.GetQuery(qry, params)

	client := cosmosdbapi.GetClient(s.db.connection)
	if _, err := client.QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		cosmosQuery,
		&docs,
		queryOpts,
	); err != nil {
		return nil, err
	}
	return docs, nil
}
// queryKeyBackupNumber executes an aggregate query (one that projects its
// result "as number") inside the key backup partition and returns the decoded
// rows.
//
// On a client error the error is returned as-is together with a nil slice.
func queryKeyBackupNumber(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosNumber, error) {
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	partitionKey := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, collection)

	var numbers []KeyBackupCosmosNumber
	queryOpts := cosmosdbapi.GetQueryDocumentsOptions(partitionKey)
	cosmosQuery := cosmosdbapi.GetQuery(qry, params)

	client := cosmosdbapi.GetClient(s.db.connection)
	if _, err := client.QueryDocuments(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		cosmosQuery,
		&numbers,
		queryOpts,
	); err != nil {
		return nil, err
	}
	return numbers, nil
}
// getKeyBackup loads the key backup document docId from partition pk.
//
// It returns (nil, nil) when no document was populated by the lookup and
// (nil, err) when the lookup itself failed. The error is checked first:
// previously a failed lookup that left the response empty was reported as
// "not found" with a nil error, hiding the failure from the caller.
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*KeyBackupCosmosData, error) {
	response := KeyBackupCosmosData{}
	err := cosmosdbapi.GetDocumentOrNil(
		s.db.connection,
		s.db.cosmosConfig,
		ctx,
		pk,
		docId,
		&response)
	if err != nil {
		return nil, err
	}
	// An empty Id means the helper did not fill in a document.
	if response.Id == "" {
		return nil, nil
	}
	return &response, nil
}
// setKeyBackup replaces the stored document identified by keyBackup.Id with
// the supplied value and returns a pointer to that (local copy of the) value
// together with the client's error, if any.
func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup KeyBackupCosmosData) (*KeyBackupCosmosData, error) {
	// The replace options are built from the document's partition key and the
	// ETag it was read with.
	replaceOpts := cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag)
	client := cosmosdbapi.GetClient(s.db.connection)
	_, _, err := client.ReplaceDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		keyBackup.Id,
		&keyBackup,
		replaceOpts,
	)
	return &keyBackup, err
}
// prepare wires the statements to db and server and assigns the query texts
// and table name. It performs no I/O and always returns nil.
func (s *keyBackupStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
	s.db, s.serverName = db, server
	s.tableName = "account_e2e_room_keys"

	// Inserts and updates use direct document operations, so only the read
	// queries need a statement text.
	// s.insertBackupKeyStmt = insertBackupKeySQL
	// s.updateBackupKeyStmt = updateBackupKeySQL
	s.countKeysStmt = countKeysSQL
	s.selectKeysStmt = selectKeysSQL
	s.selectKeysByRoomIDStmt = selectKeysByRoomIDSQL
	s.selectKeysByRoomIDAndSessionIDStmt = selectKeysByRoomIDAndSessionIDSQL
	return nil
}
// countKeys returns how many keys are stored for userID under the given
// backup version.
//
// It returns -1 together with the error when the query fails, and -1 with a
// nil error when the query yields no rows at all.
func (s keyBackupStatements) countKeys(
	ctx context.Context, userID, version string,
) (count int64, err error) {
	// Cosmos DB equivalent of:
	// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	rows, err := queryKeyBackupNumber(&s, ctx, s.countKeysStmt, map[string]interface{}{
		"@x1": collection,
		"@x2": userID,
		"@x3": version,
	})
	switch {
	case err != nil:
		return -1, err
	case len(rows) == 0:
		return -1, nil
	}
	return rows[0].Number, nil
}
// insertBackupKey creates a new key backup document for userID and version
// from key, and returns whatever error the Cosmos DB client reports.
//
// Replaces the SQL statement:
// "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
// "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
func (s *keyBackupStatements) insertBackupKey(
	ctx context.Context, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
	tenant := s.db.cosmosConfig.TenantName
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)

	// The document id is built from the columns of the original unique index:
	// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
	rowID := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
	documentID := cosmosdbapi.GetDocumentId(tenant, collection, rowID)
	partitionKey := cosmosdbapi.GetPartitionKey(tenant, collection)

	dbData := &KeyBackupCosmosData{
		Id:        documentID,
		Tn:        tenant,
		Cn:        collection,
		Pk:        partitionKey,
		Timestamp: time.Now().Unix(),
		KeyBackup: KeyBackupCosmos{
			UserId:            userID,
			RoomId:            key.RoomID,
			SessionId:         key.SessionID,
			Version:           version,
			FirstMessageIndex: key.FirstMessageIndex,
			ForwardedCount:    key.ForwardedCount,
			IsVerified:        key.IsVerified,
			SessionData:       key.SessionData,
		},
	}

	createOpts := cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
	// NOTE(review): dbData is already a pointer, so &dbData hands the client a
	// **KeyBackupCosmosData; kept unchanged to preserve existing behaviour.
	_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
		ctx,
		s.db.cosmosConfig.DatabaseName,
		s.db.cosmosConfig.ContainerName,
		&dbData,
		createOpts,
	)
	return err
}
// updateBackupKey overwrites the mutable fields (first message index,
// forwarded count, verification flag and session data) of the existing key
// backup document identified by userID, key.RoomID, key.SessionID and version.
//
// If the document does not exist nothing is written and nil is returned.
//
// Replaces the SQL statement:
// "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
// "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
func (s *keyBackupStatements) updateBackupKey(
	ctx context.Context, userID, version string, key api.InternalKeyBackupSession,
) (err error) {
	tenant := s.db.cosmosConfig.TenantName
	collection := cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)

	// Same id scheme as insertBackupKey (columns of the unique index).
	rowID := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
	documentID := cosmosdbapi.GetDocumentId(tenant, collection, rowID)
	partitionKey := cosmosdbapi.GetPartitionKey(tenant, collection)

	existing, err := getKeyBackup(s, ctx, partitionKey, documentID)
	if err != nil || existing == nil {
		// Either the lookup failed (err is returned) or there is nothing to
		// update (err is nil).
		return err
	}

	backup := &existing.KeyBackup
	backup.FirstMessageIndex = key.FirstMessageIndex
	backup.ForwardedCount = key.ForwardedCount
	backup.IsVerified = key.IsVerified
	backup.SessionData = key.SessionData

	_, err = setKeyBackup(s, ctx, *existing)
	return err
}
// selectKeys returns all backed up keys of userID for the given backup
// version, grouped as room ID -> session ID -> session.
//
// It returns (nil, nil) when no documents match and (nil, err) when the query
// fails.
//
// Replaces the SQL statement:
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2"
func (s *keyBackupStatements) selectKeys(
	ctx context.Context, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
	// The collection filter must use this table's own name. It was previously
	// derived from the openIDTokens table, so the query looked in the wrong
	// collection.
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": userID,
		"@x3": version,
	}
	rows, err := queryKeyBackup(s, ctx, s.selectKeysStmt, params)
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return nil, nil
	}
	return unpackKeys(ctx, &rows)
}
// selectKeysByRoomID returns the backed up keys of userID for the given
// backup version and room, grouped as room ID -> session ID -> session.
//
// It returns (nil, nil) when no documents match and (nil, err) when the query
// fails.
//
// Replaces the SQL statement:
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3"
func (s *keyBackupStatements) selectKeysByRoomID(
	ctx context.Context, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
	// The collection filter must use this table's own name. It was previously
	// derived from the openIDTokens table, so the query looked in the wrong
	// collection.
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": userID,
		"@x3": version,
		"@x4": roomID,
	}
	rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDStmt, params)
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return nil, nil
	}
	return unpackKeys(ctx, &rows)
}
// selectKeysByRoomIDAndSessionID returns the backed up key of userID for the
// given backup version, room and session, in the same room ID -> session ID ->
// session shape as the other select helpers.
//
// It returns (nil, nil) when no documents match and (nil, err) when the query
// fails.
//
// Replaces the SQL statement:
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
	ctx context.Context, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
	// The collection filter must use this table's own name. It was previously
	// derived from the openIDTokens table, so the query looked in the wrong
	// collection.
	var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
	params := map[string]interface{}{
		"@x1": dbCollectionName,
		"@x2": userID,
		"@x3": version,
		"@x4": roomID,
		"@x5": sessionID,
	}
	rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDAndSessionIDStmt, params)
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return nil, nil
	}
	return unpackKeys(ctx, &rows)
}
// unpackKeys converts key backup documents into the nested
// room ID -> session ID -> session map used by the user API.
//
// The returned map is never nil and the error is always nil. ctx is unused.
func unpackKeys(ctx context.Context, rows *[]KeyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) {
	result := make(map[string]map[string]api.KeyBackupSession)
	for _, item := range *rows {
		var key api.InternalKeyBackupSession
		// Columns of the original query:
		// room_id, session_id, first_message_index, forwarded_count, is_verified, session_data
		key.RoomID = item.KeyBackup.RoomId
		key.SessionID = item.KeyBackup.SessionId
		key.FirstMessageIndex = item.KeyBackup.FirstMessageIndex
		key.ForwardedCount = item.KeyBackup.ForwardedCount
		// is_verified and session_data were previously dropped: the flag was
		// never copied and the session data came from a string that was never
		// assigned, so every returned session was unverified and empty.
		key.IsVerified = item.KeyBackup.IsVerified
		key.SessionData = json.RawMessage(item.KeyBackup.SessionData)

		roomData := result[key.RoomID]
		if roomData == nil {
			roomData = make(map[string]api.KeyBackupSession)
		}
		roomData[key.SessionID] = key.KeyBackupSession
		result[key.RoomID] = roomData
	}
	return result, nil
}