mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 08:13:09 -06:00
Implement Cosmos DB for the KeyServer Service (#6)
* - Implement Cosmos for the devices_table - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 * - Change the Config to use "test.criticicalarc.com" Container - Add generic function GetDocumentOrNil to standardize GetDocument - Add func to return CrossPartition queries for Aggregates - Add func GetNextSequence() as generic seq generator for AutoIncrement - Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows - Add a "fake" ExclusiveWriterFake - Add standard "getXX", "setXX" and "queryXX" to all TABLE class files - Add specific Table SEQ for the Events table - Add specific Table SEQ for the Rooms table - Add specific Table SEQ for the StateSnapshot table * - Use CosmosDB for the KeyServer - Replace the ConnString in the YAML to Cosmos - Update the 4 tables to use Cosmos
This commit is contained in:
parent
5d68daef80
commit
b4382bd8b9
|
|
@ -228,7 +228,7 @@ key_server:
|
|||
listen: http://localhost:7779
|
||||
connect: http://localhost:7779
|
||||
database:
|
||||
connection_string: file:keyserver.db
|
||||
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
|
||||
max_open_conns: 10
|
||||
max_idle_conns: 2
|
||||
conn_max_lifetime: -1
|
||||
|
|
|
|||
16
internal/cosmosdbutil/writer.go
Normal file
16
internal/cosmosdbutil/writer.go
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
package cosmosdbutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// The Writer interface is designed to solve the problem of how
|
||||
// to handle database writes for database engines that don't allow
|
||||
// concurrent writes, e.g. SQLite.
|
||||
//
|
||||
|
||||
// Copied for CosmosDB compatibility
|
||||
|
||||
type Writer interface {
|
||||
Do(db *sql.DB, txn *sql.Tx ,f func(txn *sql.Tx) error) error
|
||||
}
|
||||
|
|
@ -17,134 +17,318 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
-- Stores device keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
stream_id BIGINT NOT NULL,
|
||||
display_name TEXT,
|
||||
-- Clobber based on tuple of user/device.
|
||||
UNIQUE (user_id, device_id)
|
||||
);
|
||||
`
|
||||
// var deviceKeysSchema = `
|
||||
// -- Stores device keys for users
|
||||
// CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
// user_id TEXT NOT NULL,
|
||||
// device_id TEXT NOT NULL,
|
||||
// ts_added_secs BIGINT NOT NULL,
|
||||
// key_json TEXT NOT NULL,
|
||||
// stream_id BIGINT NOT NULL,
|
||||
// display_name TEXT,
|
||||
// -- Clobber based on tuple of user/device.
|
||||
// UNIQUE (user_id, device_id)
|
||||
// );
|
||||
// `
|
||||
|
||||
const upsertDeviceKeysSQL = "" +
|
||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (user_id, device_id)" +
|
||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
const selectDeviceKeysSQL = "" +
|
||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
type DeviceKeyCosmos struct {
|
||||
UserID string `json:"user_id"`
|
||||
DeviceID string `json:"device_id"`
|
||||
// Use the CosmosDB.Timestamp for this one
|
||||
// TSAddedSecs int64 `json:"ts_added_secs"`
|
||||
KeyJSON []byte `json:"key_json"`
|
||||
StreamID int `json:"stream_id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(deviceKeysSchema)
|
||||
type DeviceKeyCosmosNumber struct {
|
||||
Number int64 `json:"number"`
|
||||
}
|
||||
|
||||
type DeviceKeyCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"`
|
||||
}
|
||||
|
||||
// const upsertDeviceKeysSQL = "" +
|
||||
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
// " ON CONFLICT (user_id, device_id)" +
|
||||
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
// const selectDeviceKeysSQL = "" +
|
||||
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_device_key.user_id = @x2 " +
|
||||
"and c.mx_keyserver_device_key.key_json <> \"\""
|
||||
|
||||
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_device_key.user_id = @x2 "
|
||||
|
||||
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"select count(c._ts) as number from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_device_key.user_id = @x2 " +
|
||||
"and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) "
|
||||
|
||||
const selectAllDeviceKeysSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_device_key.user_id = @x2 "
|
||||
|
||||
// const deleteAllDeviceKeysSQL = "" +
|
||||
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
var response []DeviceKeyCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) {
|
||||
var response []DeviceKeyCosmosNumber
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
if len(response) == 0 {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) {
|
||||
response := DeviceKeyCosmosData{}
|
||||
err := cosmosdbapi.GetDocumentOrNil(
|
||||
s.db.connection,
|
||||
s.db.cosmosConfig,
|
||||
ctx,
|
||||
pk,
|
||||
docId,
|
||||
&response)
|
||||
|
||||
if response.Id == "" {
|
||||
return nil, cosmosdbutil.ErrNoRows
|
||||
}
|
||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return &response, err
|
||||
}
|
||||
|
||||
func setDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, event DeviceKeyCosmosData) (*DeviceKeyCosmosData, error) {
|
||||
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, event.ETag)
|
||||
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
event.Id,
|
||||
&event,
|
||||
optionsReplace)
|
||||
return &event, ex
|
||||
}
|
||||
|
||||
func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos {
|
||||
return DeviceKeyCosmos{
|
||||
DeviceID: key.DeviceID,
|
||||
DisplayName: key.DisplayName,
|
||||
KeyJSON: key.KeyJSON,
|
||||
StreamID: key.StreamID,
|
||||
UserID: key.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *Database
|
||||
// upsertDeviceKeysStmt *sql.Stmt
|
||||
// selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt string
|
||||
selectMaxStreamForUserStmt string
|
||||
// deleteAllDeviceKeysStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL
|
||||
s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL
|
||||
s.tableName = "device_keys"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData.Id,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
|
||||
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
// _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
}
|
||||
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range response {
|
||||
errItem := deleteDeviceKeyCore(s, ctx, item)
|
||||
if errItem != nil {
|
||||
return errItem
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
deviceIDMap := make(map[string]bool)
|
||||
|
||||
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
}
|
||||
response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params)
|
||||
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
// defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
|
||||
var result []api.DeviceMessage
|
||||
for rows.Next() {
|
||||
for _, item := range response {
|
||||
var dk api.DeviceMessage
|
||||
dk.UserID = userID
|
||||
var keyJSON string
|
||||
// var keyJSON string
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk.KeyJSON = []byte(keyJSON)
|
||||
// var displayName sql.NullString
|
||||
// if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
streamID = item.DeviceKey.StreamID
|
||||
|
||||
dk.KeyJSON = item.DeviceKey.KeyJSON
|
||||
dk.StreamID = streamID
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
if len(item.DeviceKey.DisplayName) > 0 {
|
||||
dk.DisplayName = item.DeviceKey.DisplayName
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
result = append(result, dk)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var keyJSON []byte
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
|
||||
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
// err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
// UNIQUE (user_id, device_id)
|
||||
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
response, err := getDeviceKey(s, ctx, pk, cosmosDocId)
|
||||
|
||||
if err != nil && err != cosmosdbutil.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
if response != nil {
|
||||
keyJSON = response.DeviceKey.KeyJSON
|
||||
streamID = response.DeviceKey.StreamID
|
||||
displayName.String = response.DeviceKey.DisplayName
|
||||
}
|
||||
|
||||
// this will be '' when there is no device
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].KeyJSON = keyJSON
|
||||
keys[i].StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
|
|
@ -156,10 +340,30 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
|
||||
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
}
|
||||
|
||||
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
|
||||
|
||||
if err != nil {
|
||||
if err == cosmosdbutil.ErrNoRows {
|
||||
err = nil
|
||||
} else {
|
||||
return nullStream.Int32, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(response) > 0 {
|
||||
nullStream.Int32 = int32(response[0].Number)
|
||||
}
|
||||
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int32
|
||||
}
|
||||
|
|
@ -167,30 +371,66 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
|
||||
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||
iStreamIDs[0] = userID
|
||||
for i := range streamIDs {
|
||||
iStreamIDs[i+1] = streamIDs[i]
|
||||
}
|
||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
"@x3": iStreamIDs,
|
||||
}
|
||||
|
||||
// query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// // nullable if there are no results
|
||||
// var count sql.NullInt32
|
||||
// err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
|
||||
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
// if count.Valid {
|
||||
// return int(count.Int32), nil
|
||||
// }
|
||||
if response[0].Number >= 0 {
|
||||
return int(response[0].Number), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
|
||||
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
// " ON CONFLICT (user_id, device_id)" +
|
||||
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
// UNIQUE (user_id, device_id)
|
||||
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
|
||||
dbData := &DeviceKeyCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: now,
|
||||
DeviceKey: mapFromDeviceKeyMessage(key),
|
||||
}
|
||||
|
||||
err := insertDeviceKeyCore(s, ctx, *dbData)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,64 +16,139 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
offset BIGINT NOT NULL,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (partition, offset)
|
||||
);
|
||||
`
|
||||
// var keyChangesSchema = `
|
||||
// -- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
// CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
// partition BIGINT NOT NULL,
|
||||
// offset BIGINT NOT NULL,
|
||||
// -- The key owner
|
||||
// user_id TEXT NOT NULL,
|
||||
// UNIQUE (partition, offset)
|
||||
// );
|
||||
// `
|
||||
|
||||
type KeyChangeCosmos struct {
|
||||
Partition int32 `json:"partition"`
|
||||
Offset int64 `json:"_offset"` //offset is reserved
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type KeyChangeUserMaxCosmosData struct {
|
||||
UserID string `json:"user_id"`
|
||||
MaxOffset int64 `json:"max_offset"`
|
||||
}
|
||||
|
||||
type KeyChangeCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"`
|
||||
}
|
||||
|
||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (partition, offset)" +
|
||||
" DO UPDATE SET user_id = $3"
|
||||
// const upsertKeyChangeSQL = "" +
|
||||
// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||
// " VALUES ($1, $2, $3)" +
|
||||
// " ON CONFLICT (partition, offset)" +
|
||||
// " DO UPDATE SET user_id = $3"
|
||||
|
||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||
// take the max offset value as the latest offset.
|
||||
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||
"select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " +
|
||||
"from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_key_change.partition = @x2 " +
|
||||
"and c.mx_keyserver_key_change._offset > @x3 " +
|
||||
"and c.mx_keyserver_key_change._offset < @x4 " +
|
||||
"group by c.mx_keyserver_key_change.user_id "
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
db *Database
|
||||
// upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||
func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) {
|
||||
var response []KeyChangeUserMaxCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
// When there are no Rows we seem to get the generic Bad Req JSON error
|
||||
if err != nil {
|
||||
// return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) {
|
||||
s := &keyChangesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.selectKeyChangesStmt = selectKeyChangesSQL
|
||||
s.tableName = "key_changes"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
|
||||
// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||
// " VALUES ($1, $2, $3)" +
|
||||
// " ON CONFLICT (partition, offset)" +
|
||||
// " DO UPDATE SET user_id = $3"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
// UNIQUE (partition, offset)
|
||||
docId := fmt.Sprintf("%d_%d", partition, offset)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
|
||||
data := KeyChangeCosmos{
|
||||
Offset: offset,
|
||||
Partition: partition,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
dbData := KeyChangeCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
KeyChange: data,
|
||||
}
|
||||
|
||||
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -84,17 +159,29 @@ func (s *keyChangesStatements) SelectKeyChanges(
|
|||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
|
||||
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||
// rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": partition,
|
||||
"@x3": fromOffset,
|
||||
"@x4": toOffset,
|
||||
}
|
||||
|
||||
response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
|
||||
for _, item := range response {
|
||||
var userID string
|
||||
var offset int64
|
||||
if err := rows.Scan(&userID, &offset); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
userID = item.UserID
|
||||
offset = item.MaxOffset
|
||||
if offset > latestOffset {
|
||||
latestOffset = offset
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,87 +18,194 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
);
|
||||
`
|
||||
// var oneTimeKeysSchema = `
|
||||
// -- Stores one-time public keys for users
|
||||
// CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
// user_id TEXT NOT NULL,
|
||||
// device_id TEXT NOT NULL,
|
||||
// key_id TEXT NOT NULL,
|
||||
// algorithm TEXT NOT NULL,
|
||||
// ts_added_secs BIGINT NOT NULL,
|
||||
// key_json TEXT NOT NULL,
|
||||
// -- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
// UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
// );
|
||||
// `
|
||||
|
||||
const upsertKeysSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $6"
|
||||
type OneTimeKeyCosmos struct {
|
||||
UserID string `json:"user_id"`
|
||||
DeviceID string `json:"device_id"`
|
||||
KeyID string `json:"key_id"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
// Use the CosmosDB.Timestamp for this one
|
||||
// ts_added_secs int64 `json:"ts_added_secs"`
|
||||
KeyJSON []byte `json:"key_json"`
|
||||
}
|
||||
|
||||
type OneTimeKeyAlgoCountCosmosData struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type OneTimeKeyCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"`
|
||||
}
|
||||
|
||||
// const upsertKeysSQL = "" +
|
||||
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||
// " DO UPDATE SET key_json = $6"
|
||||
|
||||
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
const selectKeysSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||
"and c.mx_keyserver_one_time_key.device_id = @x3 "
|
||||
|
||||
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
"select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " +
|
||||
"from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
|
||||
"group by c.mx_keyserver_one_time_key.algorithm "
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
"select top 1 * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
|
||||
"and c.mx_keyserver_one_time_key.algorithm = @x4 "
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
db *Database
|
||||
// upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt string
|
||||
selectKeysCountStmt string
|
||||
selectKeyByAlgorithmStmt string
|
||||
// deleteOneTimeKeyStmt *sql.Stmt
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeKeysSchema)
|
||||
func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) {
|
||||
var response []OneTimeKeyCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) {
|
||||
var response []OneTimeKeyAlgoCountCosmosData
|
||||
var test interface{}
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&test,
|
||||
optionsQry)
|
||||
|
||||
// When there are no Rows we seem to get the generic Bad Req JSON error
|
||||
if err != nil {
|
||||
// return nil, err
|
||||
}
|
||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
|
||||
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData.Id,
|
||||
options)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
func NewCosmosDBOneTimeKeysTable(db *Database) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
s.selectKeysStmt = selectKeysSQL
|
||||
s.selectKeysCountStmt = selectKeysCountSQL
|
||||
s.selectKeyByAlgorithmStmt = selectKeyByAlgorithmSQL
|
||||
s.tableName = "one_time_keys"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
|
||||
|
||||
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
"@x3": deviceID,
|
||||
}
|
||||
|
||||
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
|
|
@ -106,19 +213,18 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
|||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
for rows.Next() {
|
||||
for _, item := range response {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyID = item.OneTimeKey.KeyID
|
||||
algorithm = item.OneTimeKey.Algorithm
|
||||
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
result[keyIDWithAlgo] = item.OneTimeKey.KeyJSON
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
|
|
@ -127,17 +233,26 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
|
|||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
// rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": counts.UserID,
|
||||
"@x3": counts.DeviceID,
|
||||
}
|
||||
|
||||
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
|
||||
for _, item := range response {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
algorithm = item.Algorithm
|
||||
count = item.Count
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
|
|
@ -152,30 +267,68 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
|||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
|
||||
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||
// " DO UPDATE SET key_json = $6"
|
||||
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
|
||||
// UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
|
||||
data := OneTimeKeyCosmos{
|
||||
Algorithm: algo,
|
||||
DeviceID: keys.DeviceID,
|
||||
KeyID: keyID,
|
||||
KeyJSON: keyJSON,
|
||||
UserID: keys.UserID,
|
||||
}
|
||||
|
||||
dbData := &OneTimeKeyCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: now,
|
||||
OneTimeKey: data,
|
||||
}
|
||||
|
||||
err := insertOneTimeKeyCore(s, ctx, *dbData)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
// rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": keys.UserID,
|
||||
"@x3": keys.DeviceID,
|
||||
}
|
||||
|
||||
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
|
||||
for _, item := range response {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
algorithm = item.Algorithm
|
||||
count = item.Count
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
|
|
@ -183,14 +336,25 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
|||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
|
||||
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": userID,
|
||||
"@x3": deviceID,
|
||||
"@x4": algorithm,
|
||||
}
|
||||
|
||||
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if err == cosmosdbutil.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
err = deleteOneTimeKeyCore(s, ctx, response[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,78 +16,154 @@ package cosmosdb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var staleDeviceListsSchema = `
|
||||
-- Stores whether a user's device lists are stale or not.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
user_id TEXT PRIMARY KEY NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
is_stale BOOLEAN NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL
|
||||
);
|
||||
// var staleDeviceListsSchema = `
|
||||
// -- Stores whether a user's device lists are stale or not.
|
||||
// CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
// user_id TEXT PRIMARY KEY NOT NULL,
|
||||
// domain TEXT NOT NULL,
|
||||
// is_stale BOOLEAN NOT NULL,
|
||||
// ts_added_secs BIGINT NOT NULL
|
||||
// );
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
`
|
||||
// CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
// `
|
||||
|
||||
const upsertStaleDeviceListSQL = "" +
|
||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id)" +
|
||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
db *sql.DB
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
type StaleDeviceListCosmos struct {
|
||||
UserID string `json:"user_id"`
|
||||
Domain string `json:"domain"`
|
||||
IsStale bool `json:"is_stale"`
|
||||
// Use the CosmosDB.Timestamp for this one
|
||||
// ts_added_secs int64 `json:"ts_added_secs"`
|
||||
}
|
||||
|
||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(staleDeviceListsSchema)
|
||||
type StaleDeviceListCosmosData struct {
|
||||
Id string `json:"id"`
|
||||
Pk string `json:"_pk"`
|
||||
Cn string `json:"_cn"`
|
||||
ETag string `json:"_etag"`
|
||||
Timestamp int64 `json:"_ts"`
|
||||
StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"`
|
||||
}
|
||||
|
||||
// const upsertStaleDeviceListSQL = "" +
|
||||
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
// " VALUES ($1, $2, $3, $4)" +
|
||||
// " ON CONFLICT (user_id)" +
|
||||
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_stale_device_list.is_stale = @x2 " +
|
||||
"and c.mx_keyserver_stale_device_list.domain = @x3 "
|
||||
|
||||
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"select * from c where c._cn = @x1 " +
|
||||
"and c.mx_keyserver_stale_device_list.is_stale = @x2 "
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
db *Database
|
||||
// upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt string
|
||||
selectStaleDeviceListsStmt string
|
||||
tableName string
|
||||
}
|
||||
|
||||
func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) {
|
||||
var response []StaleDeviceListCosmosData
|
||||
|
||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = cosmosdbapi.GetQuery(qry, params)
|
||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
query,
|
||||
&response,
|
||||
optionsQry)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
||||
return nil, err
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{
|
||||
db: db,
|
||||
}
|
||||
s.selectStaleDeviceListsStmt = selectStaleDeviceListsSQL
|
||||
s.selectStaleDeviceListsWithDomainsStmt = selectStaleDeviceListsWithDomainsSQL
|
||||
s.tableName = "stale_device_lists"
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
|
||||
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
// " VALUES ($1, $2, $3, $4)" +
|
||||
// " ON CONFLICT (user_id)" +
|
||||
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||
// user_id TEXT PRIMARY KEY NOT NULL,
|
||||
docId := userID
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||
|
||||
data := StaleDeviceListCosmos{
|
||||
Domain: string(domain),
|
||||
IsStale: isStale,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
dbData := StaleDeviceListCosmosData{
|
||||
Id: cosmosDocId,
|
||||
Cn: dbCollectionName,
|
||||
Pk: pk,
|
||||
Timestamp: time.Now().Unix(),
|
||||
StaleDeviceList: data,
|
||||
}
|
||||
|
||||
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||
ctx,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
dbData,
|
||||
options)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
// we only query for 1 domain or all domains so optimise for those use cases
|
||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
if len(domains) == 0 {
|
||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
|
||||
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||
// rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": true,
|
||||
}
|
||||
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -95,7 +171,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
|||
}
|
||||
var result []string
|
||||
for _, domain := range domains {
|
||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
|
||||
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||
// rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
params := map[string]interface{}{
|
||||
"@x1": dbCollectionName,
|
||||
"@x2": true,
|
||||
"@x3": string(domain),
|
||||
}
|
||||
|
||||
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -108,14 +194,11 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) {
|
||||
for _, item := range rows {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userID = item.StaleDeviceList.UserID
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,35 +15,53 @@
|
|||
package cosmosdb
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Database is used to store room events and stream offsets.
|
||||
type Database struct {
|
||||
shared.Database
|
||||
connection cosmosdbapi.CosmosConnection
|
||||
databaseName string
|
||||
cosmosConfig cosmosdbapi.CosmosConfig
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||
db, err := sqlutil.Open(dbProperties)
|
||||
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||
d := &Database{
|
||||
databaseName: "keyserver",
|
||||
connection: conn,
|
||||
cosmosConfig: config,
|
||||
}
|
||||
|
||||
// db, err := sqlutil.Open(dbProperties)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
otk, err := NewCosmosDBOneTimeKeysTable(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otk, err := NewSqliteOneTimeKeysTable(db)
|
||||
dk, err := NewCosmosDBDeviceKeysTable(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
kc, err := NewCosmosDBKeyChangesTable(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kc, err := NewSqliteKeyChangesTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
||||
sdl, err := NewCosmosDBStaleDeviceListsTable(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shared.Database{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
Writer: cosmosdbutil.NewExclusiveWriterFake(),
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
|
|
|
|||
Loading…
Reference in a new issue