Implement Cosmos DB for the KeyServer Service (#6)

* - Implement Cosmos for the devices_table
- Use the ConnectionString in the YAML to include the Tenant
- Revert all other non implemented tables back to use SQLLite3

* - Change the Config to use "test.criticicalarc.com" Container
- Add generic function GetDocumentOrNil to standardize GetDocument
- Add func to return CrossPartition queries for Aggregates
- Add func GetNextSequence() as generic seq generator for AutoIncrement
- Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows
- Add a "fake" ExclusiveWriterFake
- Add standard "getXX", "setXX" and "queryXX" to all TABLE class files
- Add specific Table SEQ for the Events table
- Add specific Table SEQ for the Rooms table
- Add specific Table SEQ for the StateSnapshot table

* - Use CosmosDB for the KeyServer
- Replace the ConnString in the YAML to Cosmos
- Update the 4 tables to use Cosmos
This commit is contained in:
alexfca 2021-05-21 09:34:30 +10:00 committed by GitHub
parent 5d68daef80
commit b4382bd8b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 872 additions and 264 deletions

View file

@ -228,7 +228,7 @@ key_server:
listen: http://localhost:7779 listen: http://localhost:7779
connect: http://localhost:7779 connect: http://localhost:7779
database: database:
connection_string: file:keyserver.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -0,0 +1,16 @@
package cosmosdbutil
import (
"database/sql"
)
// The Writer interface is designed to solve the problem of how
// to handle database writes for database engines that don't allow
// concurrent writes, e.g. SQLite.
//
// Copied for CosmosDB compatibility
type Writer interface {
Do(db *sql.DB, txn *sql.Tx ,f func(txn *sql.Tx) error) error
}

View file

@ -17,134 +17,318 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
var deviceKeysSchema = ` // var deviceKeysSchema = `
-- Stores device keys for users // -- Stores device keys for users
CREATE TABLE IF NOT EXISTS keyserver_device_keys ( // CREATE TABLE IF NOT EXISTS keyserver_device_keys (
user_id TEXT NOT NULL, // user_id TEXT NOT NULL,
device_id TEXT NOT NULL, // device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, // ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, // key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL, // stream_id BIGINT NOT NULL,
display_name TEXT, // display_name TEXT,
-- Clobber based on tuple of user/device. // -- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id) // UNIQUE (user_id, device_id)
); // );
` // `
const upsertDeviceKeysSQL = "" + type DeviceKeyCosmos struct {
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + UserID string `json:"user_id"`
" VALUES ($1, $2, $3, $4, $5, $6)" + DeviceID string `json:"device_id"`
" ON CONFLICT (user_id, device_id)" + // Use the CosmosDB.Timestamp for this one
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" // TSAddedSecs int64 `json:"ts_added_secs"`
KeyJSON []byte `json:"key_json"`
const selectDeviceKeysSQL = "" + StreamID int `json:"stream_id"`
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" DisplayName string `json:"display_name"`
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { type DeviceKeyCosmosNumber struct {
s := &deviceKeysStatements{ Number int64 `json:"number"`
db: db, }
}
_, err := db.Exec(deviceKeysSchema) type DeviceKeyCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"`
}
// const upsertDeviceKeysSQL = "" +
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id)" +
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
// const selectDeviceKeysSQL = "" +
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x2 " +
"and c.mx_keyserver_device_key.key_json <> \"\""
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x2 "
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const countStreamIDsForUserSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x2 " +
"and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) "
const selectAllDeviceKeysSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x2 "
// const deleteAllDeviceKeysSQL = "" +
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []DeviceKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { return response, nil
}
func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) {
var response []DeviceKeyCosmosNumber
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err return nil, err
} }
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
return nil, err if len(response) == 0 {
return nil, cosmosdbutil.ErrNoRows
} }
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return response, nil
}
func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) {
response := DeviceKeyCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, cosmosdbutil.ErrNoRows
} }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err return &response, err
}
func setDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, event DeviceKeyCosmosData) (*DeviceKeyCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, event.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
event.Id,
&event,
optionsReplace)
return &event, ex
}
func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
if err != nil {
return err
} }
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err return nil
}
func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos {
return DeviceKeyCosmos{
DeviceID: key.DeviceID,
DisplayName: key.DisplayName,
KeyJSON: key.KeyJSON,
StreamID: key.StreamID,
UserID: key.UserID,
} }
}
type deviceKeysStatements struct {
db *Database
// upsertDeviceKeysStmt *sql.Stmt
// selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt string
selectMaxStreamForUserStmt string
// deleteAllDeviceKeysStmt *sql.Stmt
tableName string
}
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
}
s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL
s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL
s.tableName = "device_keys"
return s, nil return s, nil
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err return err
} }
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
// _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
}
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
if err != nil {
return err
}
for _, item := range response {
errItem := deleteDeviceKeyCore(s, ctx, item)
if errItem != nil {
return errItem
}
}
return nil
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
} }
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
}
response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params)
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") // defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage var result []api.DeviceMessage
for rows.Next() { for _, item := range response {
var dk api.DeviceMessage var dk api.DeviceMessage
dk.UserID = userID dk.UserID = userID
var keyJSON string // var keyJSON string
var streamID int var streamID int
var displayName sql.NullString // var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { // if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err // return nil, err
} // }
dk.KeyJSON = []byte(keyJSON) streamID = item.DeviceKey.StreamID
dk.KeyJSON = item.DeviceKey.KeyJSON
dk.StreamID = streamID dk.StreamID = streamID
if displayName.Valid { if len(item.DeviceKey.DisplayName) > 0 {
dk.DisplayName = displayName.String dk.DisplayName = item.DeviceKey.DisplayName
} }
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)
} }
} }
return result, rows.Err() return result, nil
} }
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSON []byte
var streamID int var streamID int
var displayName sql.NullString var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { // "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
// err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (user_id, device_id)
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
response, err := getDeviceKey(s, ctx, pk, cosmosDocId)
if err != nil && err != cosmosdbutil.ErrNoRows {
return err return err
} }
if response != nil {
keyJSON = response.DeviceKey.KeyJSON
streamID = response.DeviceKey.StreamID
displayName.String = response.DeviceKey.DisplayName
}
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = keyJSON
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid { if displayName.Valid {
keys[i].DisplayName = displayName.String keys[i].DisplayName = displayName.String
@ -156,10 +340,30 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results // nullable if there are no results
var nullStream sql.NullInt32 var nullStream sql.NullInt32
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows { // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
err = nil
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
} }
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
if err != nil {
if err == cosmosdbutil.ErrNoRows {
err = nil
} else {
return nullStream.Int32, err
}
}
if len(response) > 0 {
nullStream.Int32 = int32(response[0].Number)
}
if nullStream.Valid { if nullStream.Valid {
streamID = nullStream.Int32 streamID = nullStream.Int32
} }
@ -167,30 +371,66 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
} }
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
iStreamIDs := make([]interface{}, len(streamIDs)+1) iStreamIDs := make([]interface{}, len(streamIDs)+1)
iStreamIDs[0] = userID iStreamIDs[0] = userID
for i := range streamIDs { for i := range streamIDs {
iStreamIDs[i+1] = streamIDs[i] iStreamIDs[i+1] = streamIDs[i]
} }
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var count sql.NullInt32 params := map[string]interface{}{
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) "@x1": dbCollectionName,
"@x2": userID,
"@x3": iStreamIDs,
}
// query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// // nullable if there are no results
// var count sql.NullInt32
// err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if count.Valid { // if count.Valid {
return int(count.Int32), nil // return int(count.Int32), nil
// }
if response[0].Number >= 0 {
return int(response[0].Number), nil
} }
return 0, nil return 0, nil
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id)" +
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( // UNIQUE (user_id, device_id)
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
dbData := &DeviceKeyCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: now,
DeviceKey: mapFromDeviceKeyMessage(key),
}
err := insertDeviceKeyCore(s, ctx, *dbData)
if err != nil { if err != nil {
return err return err
} }

View file

@ -16,64 +16,139 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "fmt"
"math" "math"
"time"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
var keyChangesSchema = ` // var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients. // -- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( // CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL, // partition BIGINT NOT NULL,
offset BIGINT NOT NULL, // offset BIGINT NOT NULL,
-- The key owner // -- The key owner
user_id TEXT NOT NULL, // user_id TEXT NOT NULL,
UNIQUE (partition, offset) // UNIQUE (partition, offset)
); // );
` // `
type KeyChangeCosmos struct {
Partition int32 `json:"partition"`
Offset int64 `json:"_offset"` //offset is reserved
UserID string `json:"user_id"`
}
type KeyChangeUserMaxCosmosData struct {
UserID string `json:"user_id"`
MaxOffset int64 `json:"max_offset"`
}
type KeyChangeCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"`
}
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. // Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will // Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. // miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" + // const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" + // "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
" VALUES ($1, $2, $3)" + // " VALUES ($1, $2, $3)" +
" ON CONFLICT (partition, offset)" + // " ON CONFLICT (partition, offset)" +
" DO UPDATE SET user_id = $3" // " DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just // select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset. // take the max offset value as the latest offset.
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
const selectKeyChangesSQL = "" + const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" "select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " +
"from c where c._cn = @x1 " +
"and c.mx_keyserver_key_change.partition = @x2 " +
"and c.mx_keyserver_key_change._offset > @x3 " +
"and c.mx_keyserver_key_change._offset < @x4 " +
"group by c.mx_keyserver_key_change.user_id "
type keyChangesStatements struct { type keyChangesStatements struct {
db *sql.DB db *Database
upsertKeyChangeStmt *sql.Stmt // upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt selectKeyChangesStmt string
tableName string
} }
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) {
var response []KeyChangeUserMaxCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
}
return response, nil
}
func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) {
s := &keyChangesStatements{ s := &keyChangesStatements{
db: db, db: db,
} }
_, err := db.Exec(keyChangesSchema) s.selectKeyChangesStmt = selectKeyChangesSQL
if err != nil { s.tableName = "key_changes"
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT (partition, offset)" +
// " DO UPDATE SET user_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// UNIQUE (partition, offset)
docId := fmt.Sprintf("%d_%d", partition, offset)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
data := KeyChangeCosmos{
Offset: offset,
Partition: partition,
UserID: userID,
}
dbData := KeyChangeCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
KeyChange: data,
}
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
return err return err
} }
@ -84,17 +159,29 @@ func (s *keyChangesStatements) SelectKeyChanges(
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
// rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": partition,
"@x3": fromOffset,
"@x4": toOffset,
}
response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() { for _, item := range response {
var userID string var userID string
var offset int64 var offset int64
if err := rows.Scan(&userID, &offset); err != nil { userID = item.UserID
return nil, 0, err offset = item.MaxOffset
}
if offset > latestOffset { if offset > latestOffset {
latestOffset = offset latestOffset = offset
} }

View file

@ -18,87 +18,194 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
var oneTimeKeysSchema = ` // var oneTimeKeysSchema = `
-- Stores one-time public keys for users // -- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( // CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
user_id TEXT NOT NULL, // user_id TEXT NOT NULL,
device_id TEXT NOT NULL, // device_id TEXT NOT NULL,
key_id TEXT NOT NULL, // key_id TEXT NOT NULL,
algorithm TEXT NOT NULL, // algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, // ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, // key_json TEXT NOT NULL,
-- Clobber based on 4-uple of user/device/key/algorithm. // -- Clobber based on 4-uple of user/device/key/algorithm.
UNIQUE (user_id, device_id, key_id, algorithm) // UNIQUE (user_id, device_id, key_id, algorithm)
); // );
` // `
const upsertKeysSQL = "" + type OneTimeKeyCosmos struct {
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + UserID string `json:"user_id"`
" VALUES ($1, $2, $3, $4, $5, $6)" + DeviceID string `json:"device_id"`
" ON CONFLICT (user_id, device_id, key_id, algorithm)" + KeyID string `json:"key_id"`
" DO UPDATE SET key_json = $6" Algorithm string `json:"algorithm"`
// Use the CosmosDB.Timestamp for this one
// ts_added_secs int64 `json:"ts_added_secs"`
KeyJSON []byte `json:"key_json"`
}
type OneTimeKeyAlgoCountCosmosData struct {
Algorithm string `json:"algorithm"`
Count int `json:"count"`
}
type OneTimeKeyCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"`
}
// const upsertKeysSQL = "" +
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
// " DO UPDATE SET key_json = $6"
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" "select * from c where c._cn = @x1 " +
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
"and c.mx_keyserver_one_time_key.device_id = @x3 "
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
const selectKeysCountSQL = "" + const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" "select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " +
"from c where c._cn = @x1 " +
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
"group by c.mx_keyserver_one_time_key.algorithm "
const deleteOneTimeKeySQL = "" + const deleteOneTimeKeySQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
const selectKeyByAlgorithmSQL = "" + const selectKeyByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" "select top 1 * from c where c._cn = @x1 " +
"and c.mx_keyserver_one_time_key.user_id = @x2 " +
"and c.mx_keyserver_one_time_key.device_id = @x3 " +
"and c.mx_keyserver_one_time_key.algorithm = @x4 "
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *Database
upsertKeysStmt *sql.Stmt // upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt string
selectKeysCountStmt *sql.Stmt selectKeysCountStmt string
selectKeyByAlgorithmStmt *sql.Stmt selectKeyByAlgorithmStmt string
deleteOneTimeKeyStmt *sql.Stmt // deleteOneTimeKeyStmt *sql.Stmt
tableName string
} }
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) {
s := &oneTimeKeysStatements{ var response []OneTimeKeyCosmosData
db: db,
} var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
_, err := db.Exec(oneTimeKeysSchema) var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
return nil, err return response, nil
}
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) {
var response []OneTimeKeyAlgoCountCosmosData
var test interface{}
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&test,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
} }
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return nil, err return response, nil
}
func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
if err != nil {
return err
} }
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err return nil
}
func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
} }
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { return err
return nil, err }
}
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { func NewCosmosDBOneTimeKeysTable(db *Database) (tables.OneTimeKeys, error) {
return nil, err s := &oneTimeKeysStatements{
db: db,
} }
s.selectKeysStmt = selectKeysSQL
s.selectKeysCountStmt = selectKeysCountSQL
s.selectKeyByAlgorithmStmt = selectKeyByAlgorithmSQL
s.tableName = "one_time_keys"
return s, nil return s, nil
} }
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": deviceID,
}
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms { for _, ka := range keyIDsWithAlgorithms {
@ -106,19 +213,18 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
} }
result := make(map[string]json.RawMessage) result := make(map[string]json.RawMessage)
for rows.Next() { for _, item := range response {
var keyID string var keyID string
var algorithm string var algorithm string
var keyJSONStr string keyID = item.OneTimeKey.KeyID
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { algorithm = item.OneTimeKey.Algorithm
return nil, err
}
keyIDWithAlgo := algorithm + ":" + keyID keyIDWithAlgo := algorithm + ":" + keyID
if wantSet[keyIDWithAlgo] { if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) result[keyIDWithAlgo] = item.OneTimeKey.KeyJSON
} }
} }
return result, rows.Err() return result, nil
} }
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
@ -127,17 +233,26 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
UserID: userID, UserID: userID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) // rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": counts.UserID,
"@x3": counts.DeviceID,
}
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() { for _, item := range response {
var algorithm string var algorithm string
var count int var count int
if err = rows.Scan(&algorithm, &count); err != nil { algorithm = item.Algorithm
return nil, err count = item.Count
}
counts.KeyCount[algorithm] = count counts.KeyCount[algorithm] = count
} }
return counts, nil return counts, nil
@ -152,30 +267,68 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
UserID: keys.UserID, UserID: keys.UserID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
// " DO UPDATE SET key_json = $6"
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), // UNIQUE (user_id, device_id, key_id, algorithm)
) docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
data := OneTimeKeyCosmos{
Algorithm: algo,
DeviceID: keys.DeviceID,
KeyID: keyID,
KeyJSON: keyJSON,
UserID: keys.UserID,
}
dbData := &OneTimeKeyCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: now,
OneTimeKey: data,
}
err := insertOneTimeKeyCore(s, ctx, *dbData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) // rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": keys.UserID,
"@x3": keys.DeviceID,
}
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() { for _, item := range response {
var algorithm string var algorithm string
var count int var count int
if err = rows.Scan(&algorithm, &count); err != nil { algorithm = item.Algorithm
return nil, err count = item.Count
}
counts.KeyCount[algorithm] = count counts.KeyCount[algorithm] = count
} }
return counts, rows.Err() return counts, nil
} }
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
@ -183,14 +336,25 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": userID,
"@x3": deviceID,
"@x4": algorithm,
}
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == cosmosdbutil.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) err = deleteOneTimeKeyCore(s, ctx, response[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,78 +16,154 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql"
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
var staleDeviceListsSchema = ` // var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not. // -- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( // CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL, // user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL, // domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL, // is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL // ts_added_secs BIGINT NOT NULL
); // );
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); // CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
` // `
const upsertStaleDeviceListSQL = "" + type StaleDeviceListCosmos struct {
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + UserID string `json:"user_id"`
" VALUES ($1, $2, $3, $4)" + Domain string `json:"domain"`
" ON CONFLICT (user_id)" + IsStale bool `json:"is_stale"`
" DO UPDATE SET is_stale = $3, ts_added_secs = $4" // Use the CosmosDB.Timestamp for this one
// ts_added_secs int64 `json:"ts_added_secs"`
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
} }
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { type StaleDeviceListCosmosData struct {
s := &staleDeviceListsStatements{ Id string `json:"id"`
db: db, Pk string `json:"_pk"`
} Cn string `json:"_cn"`
_, err := db.Exec(staleDeviceListsSchema) ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"`
}
// const upsertStaleDeviceListSQL = "" +
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
// " VALUES ($1, $2, $3, $4)" +
// " ON CONFLICT (user_id)" +
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsWithDomainsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_keyserver_stale_device_list.is_stale = @x2 " +
"and c.mx_keyserver_stale_device_list.domain = @x3 "
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const selectStaleDeviceListsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_keyserver_stale_device_list.is_stale = @x2 "
type staleDeviceListsStatements struct {
db *Database
// upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt string
selectStaleDeviceListsStmt string
tableName string
}
func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) {
var response []StaleDeviceListCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err return response, nil
} }
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) {
} s := &staleDeviceListsStatements{
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { db: db,
return nil, err
} }
s.selectStaleDeviceListsStmt = selectStaleDeviceListsSQL
s.selectStaleDeviceListsWithDomainsStmt = selectStaleDeviceListsWithDomainsSQL
s.tableName = "stale_device_lists"
return s, nil return s, nil
} }
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
// " VALUES ($1, $2, $3, $4)" +
// " ON CONFLICT (user_id)" +
// " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// user_id TEXT PRIMARY KEY NOT NULL,
docId := userID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
data := StaleDeviceListCosmos{
Domain: string(domain),
IsStale: isStale,
UserID: userID,
}
dbData := StaleDeviceListCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
StaleDeviceList: data,
}
// _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData,
options)
return err return err
} }
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases // we only query for 1 domain or all domains so optimise for those use cases
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if len(domains) == 0 { if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
// rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": true,
}
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,7 +171,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
} }
var result []string var result []string
for _, domain := range domains { for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
// rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": true,
"@x3": string(domain),
}
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,14 +194,11 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil return result, nil
} }
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for _, item := range rows {
for rows.Next() {
var userID string var userID string
if err := rows.Scan(&userID); err != nil { userID = item.StaleDeviceList.UserID
return nil, err
}
result = append(result, userID) result = append(result, userID)
} }
return result, rows.Err() return result, nil
} }

View file

@ -15,35 +15,53 @@
package cosmosdb package cosmosdb
import ( import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
) )
// A Database is used to store room events and stream offsets.
type Database struct {
shared.Database
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
serverName gomatrixserverlib.ServerName
}
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties) conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
d := &Database{
databaseName: "keyserver",
connection: conn,
cosmosConfig: config,
}
// db, err := sqlutil.Open(dbProperties)
// if err != nil {
// return nil, err
// }
otk, err := NewCosmosDBOneTimeKeysTable(d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
otk, err := NewSqliteOneTimeKeysTable(db) dk, err := NewCosmosDBDeviceKeysTable(d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dk, err := NewSqliteDeviceKeysTable(db) kc, err := NewCosmosDBKeyChangesTable(d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewSqliteKeyChangesTable(db) sdl, err := NewCosmosDBStaleDeviceListsTable(d)
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &shared.Database{ return &shared.Database{
DB: db, Writer: cosmosdbutil.NewExclusiveWriterFake(),
Writer: sqlutil.NewExclusiveWriter(),
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,