mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-28 09:13:09 -06:00
Implement Cosmos DB for the KeyServer Service (#6)
* - Implement Cosmos for the devices_table - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 * - Change the Config to use "test.criticicalarc.com" Container - Add generic function GetDocumentOrNil to standardize GetDocument - Add func to return CrossPartition queries for Aggregates - Add func GetNextSequence() as generic seq generator for AutoIncrement - Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows - Add a "fake" ExclusiveWriterFake - Add standard "getXX", "setXX" and "queryXX" to all TABLE class files - Add specific Table SEQ for the Events table - Add specific Table SEQ for the Rooms table - Add specific Table SEQ for the StateSnapshot table * - Use CosmosDB for the KeyServer - Replace the ConnString in the YAML to Cosmos - Update the 4 tables to use Cosmos
This commit is contained in:
parent
5d68daef80
commit
b4382bd8b9
|
|
@ -228,7 +228,7 @@ key_server:
|
||||||
listen: http://localhost:7779
|
listen: http://localhost:7779
|
||||||
connect: http://localhost:7779
|
connect: http://localhost:7779
|
||||||
database:
|
database:
|
||||||
connection_string: file:keyserver.db
|
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
|
|
||||||
16
internal/cosmosdbutil/writer.go
Normal file
16
internal/cosmosdbutil/writer.go
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
package cosmosdbutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The Writer interface is designed to solve the problem of how
|
||||||
|
// to handle database writes for database engines that don't allow
|
||||||
|
// concurrent writes, e.g. SQLite.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Copied for CosmosDB compatibility
|
||||||
|
|
||||||
|
type Writer interface {
|
||||||
|
Do(db *sql.DB, txn *sql.Tx ,f func(txn *sql.Tx) error) error
|
||||||
|
}
|
||||||
|
|
@ -17,134 +17,318 @@ package cosmosdb
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
var deviceKeysSchema = `
|
// var deviceKeysSchema = `
|
||||||
-- Stores device keys for users
|
// -- Stores device keys for users
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
// CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
user_id TEXT NOT NULL,
|
// user_id TEXT NOT NULL,
|
||||||
device_id TEXT NOT NULL,
|
// device_id TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
// ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
// key_json TEXT NOT NULL,
|
||||||
stream_id BIGINT NOT NULL,
|
// stream_id BIGINT NOT NULL,
|
||||||
display_name TEXT,
|
// display_name TEXT,
|
||||||
-- Clobber based on tuple of user/device.
|
// -- Clobber based on tuple of user/device.
|
||||||
UNIQUE (user_id, device_id)
|
// UNIQUE (user_id, device_id)
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
type DeviceKeyCosmos struct {
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
UserID string `json:"user_id"`
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
DeviceID string `json:"device_id"`
|
||||||
" ON CONFLICT (user_id, device_id)" +
|
// Use the CosmosDB.Timestamp for this one
|
||||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
// TSAddedSecs int64 `json:"ts_added_secs"`
|
||||||
|
KeyJSON []byte `json:"key_json"`
|
||||||
const selectDeviceKeysSQL = "" +
|
StreamID int `json:"stream_id"`
|
||||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
DisplayName string `json:"display_name"`
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
|
||||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
|
||||||
|
|
||||||
const countStreamIDsForUserSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
|
||||||
|
|
||||||
const deleteAllDeviceKeysSQL = "" +
|
|
||||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
|
||||||
deleteAllDeviceKeysStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
type DeviceKeyCosmosNumber struct {
|
||||||
s := &deviceKeysStatements{
|
Number int64 `json:"number"`
|
||||||
db: db,
|
}
|
||||||
}
|
|
||||||
_, err := db.Exec(deviceKeysSchema)
|
type DeviceKeyCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// const upsertDeviceKeysSQL = "" +
|
||||||
|
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
|
// " ON CONFLICT (user_id, device_id)" +
|
||||||
|
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
|
||||||
|
// const selectDeviceKeysSQL = "" +
|
||||||
|
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
|
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_device_key.user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_device_key.key_json <> \"\""
|
||||||
|
|
||||||
|
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
const selectMaxStreamForUserSQL = "" +
|
||||||
|
"select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_device_key.user_id = @x2 "
|
||||||
|
|
||||||
|
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||||
|
const countStreamIDsForUserSQL = "" +
|
||||||
|
"select count(c._ts) as number from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_device_key.user_id = @x2 " +
|
||||||
|
"and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) "
|
||||||
|
|
||||||
|
const selectAllDeviceKeysSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_device_key.user_id = @x2 "
|
||||||
|
|
||||||
|
// const deleteAllDeviceKeysSQL = "" +
|
||||||
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
var response []DeviceKeyCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) {
|
||||||
|
var response []DeviceKeyCosmosNumber
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
|
|
||||||
return nil, err
|
if len(response) == 0 {
|
||||||
|
return nil, cosmosdbutil.ErrNoRows
|
||||||
}
|
}
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
|
||||||
return nil, err
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) {
|
||||||
|
response := DeviceKeyCosmosData{}
|
||||||
|
err := cosmosdbapi.GetDocumentOrNil(
|
||||||
|
s.db.connection,
|
||||||
|
s.db.cosmosConfig,
|
||||||
|
ctx,
|
||||||
|
pk,
|
||||||
|
docId,
|
||||||
|
&response)
|
||||||
|
|
||||||
|
if response.Id == "" {
|
||||||
|
return nil, cosmosdbutil.ErrNoRows
|
||||||
}
|
}
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
|
||||||
return nil, err
|
return &response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, event DeviceKeyCosmosData) (*DeviceKeyCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, event.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
event.Id,
|
||||||
|
&event,
|
||||||
|
optionsReplace)
|
||||||
|
return &event, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
|
||||||
return nil, err
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos {
|
||||||
|
return DeviceKeyCosmos{
|
||||||
|
DeviceID: key.DeviceID,
|
||||||
|
DisplayName: key.DisplayName,
|
||||||
|
KeyJSON: key.KeyJSON,
|
||||||
|
StreamID: key.StreamID,
|
||||||
|
UserID: key.UserID,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceKeysStatements struct {
|
||||||
|
db *Database
|
||||||
|
// upsertDeviceKeysStmt *sql.Stmt
|
||||||
|
// selectDeviceKeysStmt *sql.Stmt
|
||||||
|
selectBatchDeviceKeysStmt string
|
||||||
|
selectMaxStreamForUserStmt string
|
||||||
|
// deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
|
||||||
|
s := &deviceKeysStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL
|
||||||
|
s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL
|
||||||
|
s.tableName = "device_keys"
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData.Id,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
|
||||||
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
// _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
}
|
||||||
|
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range response {
|
||||||
|
errItem := deleteDeviceKeyCore(s, ctx, item)
|
||||||
|
if errItem != nil {
|
||||||
|
return errItem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
|
|
||||||
|
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
}
|
||||||
|
response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params)
|
||||||
|
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
// defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||||
|
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var dk api.DeviceMessage
|
var dk api.DeviceMessage
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
// var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
// var displayName sql.NullString
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
// if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
// }
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
streamID = item.DeviceKey.StreamID
|
||||||
|
|
||||||
|
dk.KeyJSON = item.DeviceKey.KeyJSON
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
if displayName.Valid {
|
if len(item.DeviceKey.DisplayName) > 0 {
|
||||||
dk.DisplayName = displayName.String
|
dk.DisplayName = item.DeviceKey.DisplayName
|
||||||
}
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSON []byte
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
|
// err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// UNIQUE (user_id, device_id)
|
||||||
|
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
|
||||||
|
response, err := getDeviceKey(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil && err != cosmosdbutil.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if response != nil {
|
||||||
|
keyJSON = response.DeviceKey.KeyJSON
|
||||||
|
streamID = response.DeviceKey.StreamID
|
||||||
|
displayName.String = response.DeviceKey.DisplayName
|
||||||
|
}
|
||||||
|
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = keyJSON
|
||||||
keys[i].StreamID = streamID
|
keys[i].StreamID = streamID
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
keys[i].DisplayName = displayName.String
|
keys[i].DisplayName = displayName.String
|
||||||
|
|
@ -156,10 +340,30 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
// nullable if there are no results
|
// nullable if there are no results
|
||||||
var nullStream sql.NullInt32
|
var nullStream sql.NullInt32
|
||||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
|
||||||
if err == sql.ErrNoRows {
|
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
err = nil
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == cosmosdbutil.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
} else {
|
||||||
|
return nullStream.Int32, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) > 0 {
|
||||||
|
nullStream.Int32 = int32(response[0].Number)
|
||||||
|
}
|
||||||
|
|
||||||
if nullStream.Valid {
|
if nullStream.Valid {
|
||||||
streamID = nullStream.Int32
|
streamID = nullStream.Int32
|
||||||
}
|
}
|
||||||
|
|
@ -167,30 +371,66 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||||
|
|
||||||
|
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||||
|
|
||||||
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||||
iStreamIDs[0] = userID
|
iStreamIDs[0] = userID
|
||||||
for i := range streamIDs {
|
for i := range streamIDs {
|
||||||
iStreamIDs[i+1] = streamIDs[i]
|
iStreamIDs[i+1] = streamIDs[i]
|
||||||
}
|
}
|
||||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
|
||||||
// nullable if there are no results
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
var count sql.NullInt32
|
params := map[string]interface{}{
|
||||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": iStreamIDs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||||
|
// // nullable if there are no results
|
||||||
|
// var count sql.NullInt32
|
||||||
|
// err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||||
|
|
||||||
|
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if count.Valid {
|
// if count.Valid {
|
||||||
return int(count.Int32), nil
|
// return int(count.Int32), nil
|
||||||
|
// }
|
||||||
|
if response[0].Number >= 0 {
|
||||||
|
return int(response[0].Number), nil
|
||||||
}
|
}
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
|
|
||||||
|
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
|
// " ON CONFLICT (user_id, device_id)" +
|
||||||
|
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
// UNIQUE (user_id, device_id)
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
|
||||||
)
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
dbData := &DeviceKeyCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: now,
|
||||||
|
DeviceKey: mapFromDeviceKeyMessage(key),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := insertDeviceKeyCore(s, ctx, *dbData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,64 +16,139 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
var keyChangesSchema = `
|
// var keyChangesSchema = `
|
||||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
// -- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
// CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
partition BIGINT NOT NULL,
|
// partition BIGINT NOT NULL,
|
||||||
offset BIGINT NOT NULL,
|
// offset BIGINT NOT NULL,
|
||||||
-- The key owner
|
// -- The key owner
|
||||||
user_id TEXT NOT NULL,
|
// user_id TEXT NOT NULL,
|
||||||
UNIQUE (partition, offset)
|
// UNIQUE (partition, offset)
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
|
type KeyChangeCosmos struct {
|
||||||
|
Partition int32 `json:"partition"`
|
||||||
|
Offset int64 `json:"_offset"` //offset is reserved
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyChangeUserMaxCosmosData struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
MaxOffset int64 `json:"max_offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyChangeCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"`
|
||||||
|
}
|
||||||
|
|
||||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||||
const upsertKeyChangeSQL = "" +
|
// const upsertKeyChangeSQL = "" +
|
||||||
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||||
" VALUES ($1, $2, $3)" +
|
// " VALUES ($1, $2, $3)" +
|
||||||
" ON CONFLICT (partition, offset)" +
|
// " ON CONFLICT (partition, offset)" +
|
||||||
" DO UPDATE SET user_id = $3"
|
// " DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||||
// take the max offset value as the latest offset.
|
// take the max offset value as the latest offset.
|
||||||
|
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||||
const selectKeyChangesSQL = "" +
|
const selectKeyChangesSQL = "" +
|
||||||
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
"select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " +
|
||||||
|
"from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_key_change.partition = @x2 " +
|
||||||
|
"and c.mx_keyserver_key_change._offset > @x3 " +
|
||||||
|
"and c.mx_keyserver_key_change._offset < @x4 " +
|
||||||
|
"group by c.mx_keyserver_key_change.user_id "
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *Database
|
||||||
upsertKeyChangeStmt *sql.Stmt
|
// upsertKeyChangeStmt *sql.Stmt
|
||||||
selectKeyChangesStmt *sql.Stmt
|
selectKeyChangesStmt string
|
||||||
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) {
|
||||||
|
var response []KeyChangeUserMaxCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
// When there are no Rows we seem to get the generic Bad Req JSON error
|
||||||
|
if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) {
|
||||||
s := &keyChangesStatements{
|
s := &keyChangesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(keyChangesSchema)
|
s.selectKeyChangesStmt = selectKeyChangesSQL
|
||||||
if err != nil {
|
s.tableName = "key_changes"
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
|
||||||
|
// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||||
|
// " VALUES ($1, $2, $3)" +
|
||||||
|
// " ON CONFLICT (partition, offset)" +
|
||||||
|
// " DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
// UNIQUE (partition, offset)
|
||||||
|
docId := fmt.Sprintf("%d_%d", partition, offset)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
data := KeyChangeCosmos{
|
||||||
|
Offset: offset,
|
||||||
|
Partition: partition,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := KeyChangeCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
KeyChange: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,17 +159,29 @@ func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
toOffset = math.MaxInt64
|
toOffset = math.MaxInt64
|
||||||
}
|
}
|
||||||
latestOffset = fromOffset
|
latestOffset = fromOffset
|
||||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
|
||||||
|
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||||
|
// rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": partition,
|
||||||
|
"@x3": fromOffset,
|
||||||
|
"@x4": toOffset,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var userID string
|
var userID string
|
||||||
var offset int64
|
var offset int64
|
||||||
if err := rows.Scan(&userID, &offset); err != nil {
|
userID = item.UserID
|
||||||
return nil, 0, err
|
offset = item.MaxOffset
|
||||||
}
|
|
||||||
if offset > latestOffset {
|
if offset > latestOffset {
|
||||||
latestOffset = offset
|
latestOffset = offset
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,87 +18,194 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
var oneTimeKeysSchema = `
|
// var oneTimeKeysSchema = `
|
||||||
-- Stores one-time public keys for users
|
// -- Stores one-time public keys for users
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
// CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||||
user_id TEXT NOT NULL,
|
// user_id TEXT NOT NULL,
|
||||||
device_id TEXT NOT NULL,
|
// device_id TEXT NOT NULL,
|
||||||
key_id TEXT NOT NULL,
|
// key_id TEXT NOT NULL,
|
||||||
algorithm TEXT NOT NULL,
|
// algorithm TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
// ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
// key_json TEXT NOT NULL,
|
||||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
// -- Clobber based on 4-uple of user/device/key/algorithm.
|
||||||
UNIQUE (user_id, device_id, key_id, algorithm)
|
// UNIQUE (user_id, device_id, key_id, algorithm)
|
||||||
);
|
// );
|
||||||
`
|
// `
|
||||||
|
|
||||||
const upsertKeysSQL = "" +
|
type OneTimeKeyCosmos struct {
|
||||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
UserID string `json:"user_id"`
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
DeviceID string `json:"device_id"`
|
||||||
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
KeyID string `json:"key_id"`
|
||||||
" DO UPDATE SET key_json = $6"
|
Algorithm string `json:"algorithm"`
|
||||||
|
// Use the CosmosDB.Timestamp for this one
|
||||||
|
// ts_added_secs int64 `json:"ts_added_secs"`
|
||||||
|
KeyJSON []byte `json:"key_json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OneTimeKeyAlgoCountCosmosData struct {
|
||||||
|
Algorithm string `json:"algorithm"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OneTimeKeyCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// const upsertKeysSQL = "" +
|
||||||
|
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
|
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||||
|
// " DO UPDATE SET key_json = $6"
|
||||||
|
|
||||||
|
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
const selectKeysSQL = "" +
|
const selectKeysSQL = "" +
|
||||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.device_id = @x3 "
|
||||||
|
|
||||||
|
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
const selectKeysCountSQL = "" +
|
const selectKeysCountSQL = "" +
|
||||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
"select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " +
|
||||||
|
"from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
|
||||||
|
"group by c.mx_keyserver_one_time_key.algorithm "
|
||||||
|
|
||||||
const deleteOneTimeKeySQL = "" +
|
const deleteOneTimeKeySQL = "" +
|
||||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||||
|
|
||||||
|
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||||
const selectKeyByAlgorithmSQL = "" +
|
const selectKeyByAlgorithmSQL = "" +
|
||||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
"select top 1 * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
|
||||||
|
"and c.mx_keyserver_one_time_key.algorithm = @x4 "
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *Database
|
||||||
upsertKeysStmt *sql.Stmt
|
// upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt string
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt string
|
||||||
selectKeyByAlgorithmStmt *sql.Stmt
|
selectKeyByAlgorithmStmt string
|
||||||
deleteOneTimeKeyStmt *sql.Stmt
|
// deleteOneTimeKeyStmt *sql.Stmt
|
||||||
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) {
|
||||||
s := &oneTimeKeysStatements{
|
var response []OneTimeKeyCosmosData
|
||||||
db: db,
|
|
||||||
}
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
_, err := db.Exec(oneTimeKeysSchema)
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
|
|
||||||
return nil, err
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) {
|
||||||
|
var response []OneTimeKeyAlgoCountCosmosData
|
||||||
|
var test interface{}
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&test,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
// When there are no Rows we seem to get the generic Bad Req JSON error
|
||||||
|
if err != nil {
|
||||||
|
// return nil, err
|
||||||
}
|
}
|
||||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
|
||||||
return nil, err
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
|
||||||
return nil, err
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
|
||||||
|
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData.Id,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
return err
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
func NewCosmosDBOneTimeKeysTable(db *Database) (tables.OneTimeKeys, error) {
|
||||||
return nil, err
|
s := &oneTimeKeysStatements{
|
||||||
|
db: db,
|
||||||
}
|
}
|
||||||
|
s.selectKeysStmt = selectKeysSQL
|
||||||
|
s.selectKeysCountStmt = selectKeysCountSQL
|
||||||
|
s.selectKeyByAlgorithmStmt = selectKeyByAlgorithmSQL
|
||||||
|
s.tableName = "one_time_keys"
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
|
|
||||||
|
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": deviceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
|
||||||
|
|
||||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||||
for _, ka := range keyIDsWithAlgorithms {
|
for _, ka := range keyIDsWithAlgorithms {
|
||||||
|
|
@ -106,19 +213,18 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make(map[string]json.RawMessage)
|
result := make(map[string]json.RawMessage)
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var keyID string
|
var keyID string
|
||||||
var algorithm string
|
var algorithm string
|
||||||
var keyJSONStr string
|
keyID = item.OneTimeKey.KeyID
|
||||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
algorithm = item.OneTimeKey.Algorithm
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
keyIDWithAlgo := algorithm + ":" + keyID
|
keyIDWithAlgo := algorithm + ":" + keyID
|
||||||
if wantSet[keyIDWithAlgo] {
|
if wantSet[keyIDWithAlgo] {
|
||||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
result[keyIDWithAlgo] = item.OneTimeKey.KeyJSON
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
|
@ -127,17 +233,26 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
KeyCount: make(map[string]int),
|
KeyCount: make(map[string]int),
|
||||||
}
|
}
|
||||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
// rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": counts.UserID,
|
||||||
|
"@x3": counts.DeviceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var algorithm string
|
var algorithm string
|
||||||
var count int
|
var count int
|
||||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
algorithm = item.Algorithm
|
||||||
return nil, err
|
count = item.Count
|
||||||
}
|
|
||||||
counts.KeyCount[algorithm] = count
|
counts.KeyCount[algorithm] = count
|
||||||
}
|
}
|
||||||
return counts, nil
|
return counts, nil
|
||||||
|
|
@ -152,30 +267,68 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
||||||
UserID: keys.UserID,
|
UserID: keys.UserID,
|
||||||
KeyCount: make(map[string]int),
|
KeyCount: make(map[string]int),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
|
||||||
|
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
|
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||||
|
// " DO UPDATE SET key_json = $6"
|
||||||
|
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
|
||||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
// UNIQUE (user_id, device_id, key_id, algorithm)
|
||||||
)
|
docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
data := OneTimeKeyCosmos{
|
||||||
|
Algorithm: algo,
|
||||||
|
DeviceID: keys.DeviceID,
|
||||||
|
KeyID: keyID,
|
||||||
|
KeyJSON: keyJSON,
|
||||||
|
UserID: keys.UserID,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := &OneTimeKeyCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: now,
|
||||||
|
OneTimeKey: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := insertOneTimeKeyCore(s, ctx, *dbData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
// rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": keys.UserID,
|
||||||
|
"@x3": keys.DeviceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
|
||||||
for rows.Next() {
|
for _, item := range response {
|
||||||
var algorithm string
|
var algorithm string
|
||||||
var count int
|
var count int
|
||||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
algorithm = item.Algorithm
|
||||||
return nil, err
|
count = item.Count
|
||||||
}
|
|
||||||
counts.KeyCount[algorithm] = count
|
counts.KeyCount[algorithm] = count
|
||||||
}
|
}
|
||||||
|
|
||||||
return counts, rows.Err()
|
return counts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
|
|
@ -183,14 +336,25 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
) (map[string]json.RawMessage, error) {
|
) (map[string]json.RawMessage, error) {
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
|
||||||
|
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": deviceID,
|
||||||
|
"@x4": algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == cosmosdbutil.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
err = deleteOneTimeKeyCore(s, ctx, response[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,78 +16,154 @@ package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var staleDeviceListsSchema = `
|
// var staleDeviceListsSchema = `
|
||||||
-- Stores whether a user's device lists are stale or not.
|
// -- Stores whether a user's device lists are stale or not.
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
// CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||||
user_id TEXT PRIMARY KEY NOT NULL,
|
// user_id TEXT PRIMARY KEY NOT NULL,
|
||||||
domain TEXT NOT NULL,
|
// domain TEXT NOT NULL,
|
||||||
is_stale BOOLEAN NOT NULL,
|
// is_stale BOOLEAN NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL
|
// ts_added_secs BIGINT NOT NULL
|
||||||
);
|
// );
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
// CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||||
`
|
// `
|
||||||
|
|
||||||
const upsertStaleDeviceListSQL = "" +
|
type StaleDeviceListCosmos struct {
|
||||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
UserID string `json:"user_id"`
|
||||||
" VALUES ($1, $2, $3, $4)" +
|
Domain string `json:"domain"`
|
||||||
" ON CONFLICT (user_id)" +
|
IsStale bool `json:"is_stale"`
|
||||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
// Use the CosmosDB.Timestamp for this one
|
||||||
|
// ts_added_secs int64 `json:"ts_added_secs"`
|
||||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
|
||||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
|
||||||
|
|
||||||
const selectStaleDeviceListsSQL = "" +
|
|
||||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
|
||||||
|
|
||||||
type staleDeviceListsStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
upsertStaleDeviceListStmt *sql.Stmt
|
|
||||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
|
||||||
selectStaleDeviceListsStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
type StaleDeviceListCosmosData struct {
|
||||||
s := &staleDeviceListsStatements{
|
Id string `json:"id"`
|
||||||
db: db,
|
Pk string `json:"_pk"`
|
||||||
}
|
Cn string `json:"_cn"`
|
||||||
_, err := db.Exec(staleDeviceListsSchema)
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// const upsertStaleDeviceListSQL = "" +
|
||||||
|
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4)" +
|
||||||
|
// " ON CONFLICT (user_id)" +
|
||||||
|
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||||
|
|
||||||
|
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||||
|
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_stale_device_list.is_stale = @x2 " +
|
||||||
|
"and c.mx_keyserver_stale_device_list.domain = @x3 "
|
||||||
|
|
||||||
|
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||||
|
const selectStaleDeviceListsSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_stale_device_list.is_stale = @x2 "
|
||||||
|
|
||||||
|
type staleDeviceListsStatements struct {
|
||||||
|
db *Database
|
||||||
|
// upsertStaleDeviceListStmt *sql.Stmt
|
||||||
|
selectStaleDeviceListsWithDomainsStmt string
|
||||||
|
selectStaleDeviceListsStmt string
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) {
|
||||||
|
var response []StaleDeviceListCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
|
||||||
return nil, err
|
return response, nil
|
||||||
}
|
}
|
||||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
|
||||||
return nil, err
|
func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) {
|
||||||
}
|
s := &staleDeviceListsStatements{
|
||||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
db: db,
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
s.selectStaleDeviceListsStmt = selectStaleDeviceListsSQL
|
||||||
|
s.selectStaleDeviceListsWithDomainsStmt = selectStaleDeviceListsWithDomainsSQL
|
||||||
|
s.tableName = "stale_device_lists"
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
|
||||||
|
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||||
|
// " VALUES ($1, $2, $3, $4)" +
|
||||||
|
// " ON CONFLICT (user_id)" +
|
||||||
|
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||||
|
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
|
||||||
|
// user_id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
docId := userID
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
data := StaleDeviceListCosmos{
|
||||||
|
Domain: string(domain),
|
||||||
|
IsStale: isStale,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := StaleDeviceListCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
StaleDeviceList: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
// we only query for 1 domain or all domains so optimise for those use cases
|
// we only query for 1 domain or all domains so optimise for those use cases
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
if len(domains) == 0 {
|
if len(domains) == 0 {
|
||||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
|
||||||
|
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||||
|
// rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": true,
|
||||||
|
}
|
||||||
|
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -95,7 +171,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
||||||
}
|
}
|
||||||
var result []string
|
var result []string
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
|
||||||
|
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||||
|
// rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": true,
|
||||||
|
"@x3": string(domain),
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -108,14 +194,11 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) {
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
for _, item := range rows {
|
||||||
for rows.Next() {
|
|
||||||
var userID string
|
var userID string
|
||||||
if err := rows.Scan(&userID); err != nil {
|
userID = item.StaleDeviceList.UserID
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
result = append(result, userID)
|
result = append(result, userID)
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,35 +15,53 @@
|
||||||
package cosmosdb
|
package cosmosdb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A Database is used to store room events and stream offsets.
|
||||||
|
type Database struct {
|
||||||
|
shared.Database
|
||||||
|
connection cosmosdbapi.CosmosConnection
|
||||||
|
databaseName string
|
||||||
|
cosmosConfig cosmosdbapi.CosmosConfig
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||||
|
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||||
|
d := &Database{
|
||||||
|
databaseName: "keyserver",
|
||||||
|
connection: conn,
|
||||||
|
cosmosConfig: config,
|
||||||
|
}
|
||||||
|
|
||||||
|
// db, err := sqlutil.Open(dbProperties)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
otk, err := NewCosmosDBOneTimeKeysTable(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
otk, err := NewSqliteOneTimeKeysTable(db)
|
dk, err := NewCosmosDBDeviceKeysTable(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk, err := NewSqliteDeviceKeysTable(db)
|
kc, err := NewCosmosDBKeyChangesTable(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
kc, err := NewSqliteKeyChangesTable(db)
|
sdl, err := NewCosmosDBStaleDeviceListsTable(d)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
Writer: cosmosdbutil.NewExclusiveWriterFake(),
|
||||||
Writer: sqlutil.NewExclusiveWriter(),
|
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue