From b4382bd8b97032969e14a2e3755aa85469ef8ba6 Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Fri, 21 May 2021 09:34:30 +1000 Subject: [PATCH] Implement Cosmos DB for the KeyServer Service (#6) * - Implement Cosmos for the devices_table - Use the ConnectionString in the YAML to include the Tenant - Revert all other non implemented tables back to use SQLLite3 * - Change the Config to use "test.criticicalarc.com" Container - Add generic function GetDocumentOrNil to standardize GetDocument - Add func to return CrossPartition queries for Aggregates - Add func GetNextSequence() as generic seq generator for AutoIncrement - Add cosmosdbutil.ErrNoRows to return (emulate) sql.ErrNoRows - Add a "fake" ExclusiveWriterFake - Add standard "getXX", "setXX" and "queryXX" to all TABLE class files - Add specific Table SEQ for the Events table - Add specific Table SEQ for the Rooms table - Add specific Table SEQ for the StateSnapshot table * - Use CosmosDB for the KeyServer - Replace the ConnString in the YAML to Cosmos - Update the 4 tables to use Cosmos --- dendrite-config-cosmosdb.yaml | 2 +- internal/cosmosdbutil/writer.go | 16 + .../storage/cosmosdb/device_keys_table.go | 416 ++++++++++++++---- .../storage/cosmosdb/key_changes_table.go | 165 +++++-- .../storage/cosmosdb/one_time_keys_table.go | 310 ++++++++++--- .../storage/cosmosdb/stale_device_lists.go | 187 +++++--- keyserver/storage/cosmosdb/storage.go | 40 +- 7 files changed, 872 insertions(+), 264 deletions(-) create mode 100644 internal/cosmosdbutil/writer.go diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index ef7883e23..9dd980b1b 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -228,7 +228,7 @@ key_server: listen: http://localhost:7779 connect: http://localhost:7779 database: - connection_string: file:keyserver.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/internal/cosmosdbutil/writer.go b/internal/cosmosdbutil/writer.go new file mode 100644 index 000000000..ca5176fb2 --- /dev/null +++ b/internal/cosmosdbutil/writer.go @@ -0,0 +1,16 @@ +package cosmosdbutil + +import ( + "database/sql" +) + +// The Writer interface is designed to solve the problem of how +// to handle database writes for database engines that don't allow +// concurrent writes, e.g. SQLite. +// + +// Copied for CosmosDB compatibility + +type Writer interface { + Do(db *sql.DB, txn *sql.Tx ,f func(txn *sql.Tx) error) error +} diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go index 67d4da201..330642c62 100644 --- a/keyserver/storage/cosmosdb/device_keys_table.go +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -17,134 +17,318 @@ package cosmosdb import ( "context" "database/sql" - "strings" + "fmt" "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) -var deviceKeysSchema = ` --- Stores device keys for users -CREATE TABLE IF NOT EXISTS keyserver_device_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - stream_id BIGINT NOT NULL, - display_name TEXT, - -- Clobber based on tuple of user/device. - UNIQUE (user_id, device_id) -); -` +// var deviceKeysSchema = ` +// -- Stores device keys for users +// CREATE TABLE IF NOT EXISTS keyserver_device_keys ( +// user_id TEXT NOT NULL, +// device_id TEXT NOT NULL, +// ts_added_secs BIGINT NOT NULL, +// key_json TEXT NOT NULL, +// stream_id BIGINT NOT NULL, +// display_name TEXT, +// -- Clobber based on tuple of user/device. +// UNIQUE (user_id, device_id) +// ); +// ` -const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - -const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" - -const selectMaxStreamForUserSQL = "" + - "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - -const countStreamIDsForUserSQL = "" + - "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" - -const deleteAllDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1" - -type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt +type DeviceKeyCosmos struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + // Use the CosmosDB.Timestamp for this one + // TSAddedSecs int64 `json:"ts_added_secs"` + KeyJSON []byte `json:"key_json"` + StreamID int `json:"stream_id"` + DisplayName string `json:"display_name"` } -func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { - s := &deviceKeysStatements{ - db: db, - } - _, err := db.Exec(deviceKeysSchema) +type DeviceKeyCosmosNumber struct { + Number int64 `json:"number"` +} + +type DeviceKeyCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"` +} + +// const upsertDeviceKeysSQL = "" + +// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + +// " VALUES ($1, $2, $3, $4, $5, $6)" + +// " ON CONFLICT (user_id, device_id)" + +// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +// const selectDeviceKeysSQL = "" + +// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +// "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + + "and c.mx_keyserver_device_key.key_json <> \"\"" + +// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" +const selectMaxStreamForUserSQL = "" + + "select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + +// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" +const countStreamIDsForUserSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + + "and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) " + +const selectAllDeviceKeysSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + +// const deleteAllDeviceKeysSQL = "" + +// "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []DeviceKeyCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { + return response, nil +} + +func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) { + var response []DeviceKeyCosmosNumber + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { return nil, err } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err + + if len(response) == 0 { + return nil, cosmosdbutil.ErrNoRows } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err + + return response, nil +} + +func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) { + response := DeviceKeyCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err + + return &response, err +} + +func setDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, event DeviceKeyCosmosData) (*DeviceKeyCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, event.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + event.Id, + &event, + optionsReplace) + return &event, ex +} + +func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error { + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + + if err != nil { + return err } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err + + return nil +} + +func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos { + return DeviceKeyCosmos{ + DeviceID: key.DeviceID, + DisplayName: key.DisplayName, + KeyJSON: key.KeyJSON, + StreamID: key.StreamID, + UserID: key.UserID, } +} + +type deviceKeysStatements struct { + db *Database + // upsertDeviceKeysStmt *sql.Stmt + // selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt string + selectMaxStreamForUserStmt string + // deleteAllDeviceKeysStmt *sql.Stmt + tableName string +} + +func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL + s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL + s.tableName = "device_keys" return s, nil } -func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) +func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } return err } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + + // "DELETE FROM keyserver_device_keys WHERE user_id=$1" + // _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + } + response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params) + + if err != nil { + return err + } + + for _, item := range response { + errItem := deleteDeviceKeyCore(s, ctx, item) + if errItem != nil { + return errItem + } + } + return nil +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) + + // "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + for _, d := range deviceIDs { deviceIDMap[d] = true } - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + } + response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params) + // rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + // defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceMessage - for rows.Next() { + for _, item := range response { var dk api.DeviceMessage dk.UserID = userID - var keyJSON string + // var keyJSON string var streamID int - var displayName sql.NullString - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { - return nil, err - } - dk.KeyJSON = []byte(keyJSON) + // var displayName sql.NullString + // if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { + // return nil, err + // } + streamID = item.DeviceKey.StreamID + + dk.KeyJSON = item.DeviceKey.KeyJSON dk.StreamID = streamID - if displayName.Valid { - dk.DisplayName = displayName.String + if len(item.DeviceKey.DisplayName) > 0 { + dk.DisplayName = item.DeviceKey.DisplayName } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) } } - return result, rows.Err() + return result, nil } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { - var keyJSONStr string + var keyJSON []byte var streamID int var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - if err != nil && err != sql.ErrNoRows { + + // "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + + // err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (user_id, device_id) + docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + response, err := getDeviceKey(s, ctx, pk, cosmosDocId) + + if err != nil && err != cosmosdbutil.ErrNoRows { return err } + if response != nil { + keyJSON = response.DeviceKey.KeyJSON + streamID = response.DeviceKey.StreamID + displayName.String = response.DeviceKey.DisplayName + } + // this will be '' when there is no device - keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].KeyJSON = keyJSON keys[i].StreamID = streamID if displayName.Valid { keys[i].DisplayName = displayName.String @@ -156,10 +340,30 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { // nullable if there are no results var nullStream sql.NullInt32 - err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - if err == sql.ErrNoRows { - err = nil + + // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, } + + // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params) + + if err != nil { + if err == cosmosdbutil.ErrNoRows { + err = nil + } else { + return nullStream.Int32, err + } + } + + if len(response) > 0 { + nullStream.Int32 = int32(response[0].Number) + } + if nullStream.Valid { streamID = nullStream.Int32 } @@ -167,30 +371,66 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn } func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + + // "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + iStreamIDs := make([]interface{}, len(streamIDs)+1) iStreamIDs[0] = userID for i := range streamIDs { iStreamIDs[i+1] = streamIDs[i] } - query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) - // nullable if there are no results - var count sql.NullInt32 - err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": iStreamIDs, + } + + // query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // // nullable if there are no results + // var count sql.NullInt32 + // err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + + response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params) + if err != nil { return 0, err } - if count.Valid { - return int(count.Int32), nil + // if count.Valid { + // return int(count.Int32), nil + // } + if response[0].Number >= 0 { + return int(response[0].Number), nil } return 0, nil } func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + + // "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + // " VALUES ($1, $2, $3, $4, $5, $6)" + + // " ON CONFLICT (user_id, device_id)" + + // " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + for _, key := range keys { now := time.Now().Unix() - _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) + // UNIQUE (user_id, device_id) + docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + dbData := &DeviceKeyCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: now, + DeviceKey: mapFromDeviceKeyMessage(key), + } + + err := insertDeviceKeyCore(s, ctx, *dbData) + if err != nil { return err } diff --git a/keyserver/storage/cosmosdb/key_changes_table.go b/keyserver/storage/cosmosdb/key_changes_table.go index 08eef3619..6cc2faac8 100644 --- a/keyserver/storage/cosmosdb/key_changes_table.go +++ b/keyserver/storage/cosmosdb/key_changes_table.go @@ -16,64 +16,139 @@ package cosmosdb import ( "context" - "database/sql" + "fmt" "math" + "time" "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) -var keyChangesSchema = ` --- Stores key change information about users. Used to determine when to send updated device lists to clients. -CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - offset BIGINT NOT NULL, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (partition, offset) -); -` +// var keyChangesSchema = ` +// -- Stores key change information about users. Used to determine when to send updated device lists to clients. +// CREATE TABLE IF NOT EXISTS keyserver_key_changes ( +// partition BIGINT NOT NULL, +// offset BIGINT NOT NULL, +// -- The key owner +// user_id TEXT NOT NULL, +// UNIQUE (partition, offset) +// ); +// ` + +type KeyChangeCosmos struct { + Partition int32 `json:"partition"` + Offset int64 `json:"_offset"` //offset is reserved + UserID string `json:"user_id"` +} + +type KeyChangeUserMaxCosmosData struct { + UserID string `json:"user_id"` + MaxOffset int64 `json:"max_offset"` +} + +type KeyChangeCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"` +} // Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. // Rather than falling over, just overwrite (though this will mean clients with an existing sync token will // miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. -const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (partition, offset)" + - " DO UPDATE SET user_id = $3" +// const upsertKeyChangeSQL = "" + +// "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + +// " VALUES ($1, $2, $3)" + +// " ON CONFLICT (partition, offset)" + +// " DO UPDATE SET user_id = $3" // select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just // take the max offset value as the latest offset. +// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" const selectKeyChangesSQL = "" + - "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" + "select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " + + "from c where c._cn = @x1 " + + "and c.mx_keyserver_key_change.partition = @x2 " + + "and c.mx_keyserver_key_change._offset > @x3 " + + "and c.mx_keyserver_key_change._offset < @x4 " + + "group by c.mx_keyserver_key_change.user_id " type keyChangesStatements struct { - db *sql.DB - upsertKeyChangeStmt *sql.Stmt - selectKeyChangesStmt *sql.Stmt + db *Database + // upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt string + tableName string } -func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { +func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) { + var response []KeyChangeUserMaxCosmosData + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + // When there are no Rows we seem to get the generic Bad Req JSON error + if err != nil { + // return nil, err + } + + return response, nil +} + +func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) { s := &keyChangesStatements{ db: db, } - _, err := db.Exec(keyChangesSchema) - if err != nil { - return nil, err - } - if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { - return nil, err - } - if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { - return nil, err - } + s.selectKeyChangesStmt = selectKeyChangesSQL + s.tableName = "key_changes" return s, nil } func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + + // "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + + // " VALUES ($1, $2, $3)" + + // " ON CONFLICT (partition, offset)" + + // " DO UPDATE SET user_id = $3" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // UNIQUE (partition, offset) + docId := fmt.Sprintf("%d_%d", partition, offset) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + data := KeyChangeCosmos{ + Offset: offset, + Partition: partition, + UserID: userID, + } + + dbData := KeyChangeCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + KeyChange: data, + } + + // _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + return err } @@ -84,17 +159,29 @@ func (s *keyChangesStatements) SelectKeyChanges( toOffset = math.MaxInt64 } latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + + // "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" + // rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": partition, + "@x3": fromOffset, + "@x4": toOffset, + } + + response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params) + if err != nil { return nil, 0, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") - for rows.Next() { + + for _, item := range response { var userID string var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { - return nil, 0, err - } + userID = item.UserID + offset = item.MaxOffset if offset > latestOffset { latestOffset = offset } diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go index 942c7532a..6eb33d470 100644 --- a/keyserver/storage/cosmosdb/one_time_keys_table.go +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -18,87 +18,194 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) -var oneTimeKeysSchema = ` --- Stores one-time public keys for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - key_id TEXT NOT NULL, - algorithm TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- Clobber based on 4-uple of user/device/key/algorithm. - UNIQUE (user_id, device_id, key_id, algorithm) -); -` +// var oneTimeKeysSchema = ` +// -- Stores one-time public keys for users +// CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( +// user_id TEXT NOT NULL, +// device_id TEXT NOT NULL, +// key_id TEXT NOT NULL, +// algorithm TEXT NOT NULL, +// ts_added_secs BIGINT NOT NULL, +// key_json TEXT NOT NULL, +// -- Clobber based on 4-uple of user/device/key/algorithm. +// UNIQUE (user_id, device_id, key_id, algorithm) +// ); +// ` -const upsertKeysSQL = "" + - "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id, key_id, algorithm)" + - " DO UPDATE SET key_json = $6" +type OneTimeKeyCosmos struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + KeyID string `json:"key_id"` + Algorithm string `json:"algorithm"` + // Use the CosmosDB.Timestamp for this one + // ts_added_secs int64 `json:"ts_added_secs"` + KeyJSON []byte `json:"key_json"` +} +type OneTimeKeyAlgoCountCosmosData struct { + Algorithm string `json:"algorithm"` + Count int `json:"count"` +} + +type OneTimeKeyCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"` +} + +// const upsertKeysSQL = "" + +// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + +// " VALUES ($1, $2, $3, $4, $5, $6)" + +// " ON CONFLICT (user_id, device_id, key_id, algorithm)" + +// " DO UPDATE SET key_json = $6" + +// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" const selectKeysSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_one_time_key.user_id = @x2 " + + "and c.mx_keyserver_one_time_key.device_id = @x3 " +// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" const selectKeysCountSQL = "" + - "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + "select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " + + "from c where c._cn = @x1 " + + "and c.mx_keyserver_one_time_key.user_id = @x2 " + + "and c.mx_keyserver_one_time_key.device_id = @x3 " + + "group by c.mx_keyserver_one_time_key.algorithm " const deleteOneTimeKeySQL = "" + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" +// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" const selectKeyByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + "select top 1 * from c where c._cn = @x1 " + + "and c.mx_keyserver_one_time_key.user_id = @x2 " + + "and c.mx_keyserver_one_time_key.device_id = @x3 " + + "and c.mx_keyserver_one_time_key.algorithm = @x4 " type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt - selectKeyByAlgorithmStmt *sql.Stmt - deleteOneTimeKeyStmt *sql.Stmt + db *Database + // upsertKeysStmt *sql.Stmt + selectKeysStmt string + selectKeysCountStmt string + selectKeyByAlgorithmStmt string + // deleteOneTimeKeyStmt *sql.Stmt + tableName string } -func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { - s := &oneTimeKeysStatements{ - db: db, - } - _, err := db.Exec(oneTimeKeysSchema) +func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) { + var response []OneTimeKeyCosmosData + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err + + return response, nil +} + +func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) { + var response []OneTimeKeyAlgoCountCosmosData + var test interface{} + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &test, + optionsQry) + + // When there are no Rows we seem to get the generic Bad Req JSON error + if err != nil { + // return nil, err } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err + + return response, nil +} + +func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error { + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + + if err != nil { + return err } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err + + return nil +} + +func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err + return err +} + +func NewCosmosDBOneTimeKeysTable(db *Database) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, } + s.selectKeysStmt = selectKeysSQL + s.selectKeysCountStmt = selectKeysCountSQL + s.selectKeyByAlgorithmStmt = selectKeyByAlgorithmSQL + s.tableName = "one_time_keys" return s, nil } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + + // "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + } + + response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) for _, ka := range keyIDsWithAlgorithms { @@ -106,19 +213,18 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d } result := make(map[string]json.RawMessage) - for rows.Next() { + for _, item := range response { var keyID string var algorithm string - var keyJSONStr string - if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { - return nil, err - } + keyID = item.OneTimeKey.KeyID + algorithm = item.OneTimeKey.Algorithm + keyIDWithAlgo := algorithm + ":" + keyID if wantSet[keyIDWithAlgo] { - result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + result[keyIDWithAlgo] = item.OneTimeKey.KeyJSON } } - return result, rows.Err() + return result, nil } func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { @@ -127,17 +233,26 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de UserID: userID, KeyCount: make(map[string]int), } - rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + // rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": counts.UserID, + "@x3": counts.DeviceID, + } + + // "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { + + for _, item := range response { var algorithm string var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } + algorithm = item.Algorithm + count = item.Count counts.KeyCount[algorithm] = count } return counts, nil @@ -152,30 +267,68 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( UserID: keys.UserID, KeyCount: make(map[string]int), } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + + // "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + // " VALUES ($1, $2, $3, $4, $5, $6)" + + // " ON CONFLICT (user_id, device_id, key_id, algorithm)" + + // " DO UPDATE SET key_json = $6" + algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) + + // UNIQUE (user_id, device_id, key_id, algorithm) + docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + data := OneTimeKeyCosmos{ + Algorithm: algo, + DeviceID: keys.DeviceID, + KeyID: keyID, + KeyJSON: keyJSON, + UserID: keys.UserID, + } + + dbData := &OneTimeKeyCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: now, + OneTimeKey: data, + } + + err := insertOneTimeKeyCore(s, ctx, *dbData) + if err != nil { return nil, err } } - rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + // rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": keys.UserID, + "@x3": keys.DeviceID, + } + + // "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { + + for _, item := range response { var algorithm string var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } + algorithm = item.Algorithm + count = item.Count counts.KeyCount[algorithm] = count } - return counts, rows.Err() + return counts, nil } func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( @@ -183,14 +336,25 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + + // "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + "@x4": algorithm, + } + + response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params) if err != nil { - if err == sql.ErrNoRows { + if err == cosmosdbutil.ErrNoRows { return nil, nil } return nil, err } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + err = deleteOneTimeKeyCore(s, ctx, response[0]) if err != nil { return nil, err } diff --git a/keyserver/storage/cosmosdb/stale_device_lists.go b/keyserver/storage/cosmosdb/stale_device_lists.go index 2c4e0d8e2..ea84709c3 100644 --- a/keyserver/storage/cosmosdb/stale_device_lists.go +++ b/keyserver/storage/cosmosdb/stale_device_lists.go @@ -16,78 +16,154 @@ package cosmosdb import ( "context" - "database/sql" "time" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) -var staleDeviceListsSchema = ` --- Stores whether a user's device lists are stale or not. -CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( - user_id TEXT PRIMARY KEY NOT NULL, - domain TEXT NOT NULL, - is_stale BOOLEAN NOT NULL, - ts_added_secs BIGINT NOT NULL -); +// var staleDeviceListsSchema = ` +// -- Stores whether a user's device lists are stale or not. +// CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( +// user_id TEXT PRIMARY KEY NOT NULL, +// domain TEXT NOT NULL, +// is_stale BOOLEAN NOT NULL, +// ts_added_secs BIGINT NOT NULL +// ); -CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); -` +// CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +// ` -const upsertStaleDeviceListSQL = "" + - "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id)" + - " DO UPDATE SET is_stale = $3, ts_added_secs = $4" - -const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" - -const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" - -type staleDeviceListsStatements struct { - db *sql.DB - upsertStaleDeviceListStmt *sql.Stmt - selectStaleDeviceListsWithDomainsStmt *sql.Stmt - selectStaleDeviceListsStmt *sql.Stmt +type StaleDeviceListCosmos struct { + UserID string `json:"user_id"` + Domain string `json:"domain"` + IsStale bool `json:"is_stale"` + // Use the CosmosDB.Timestamp for this one + // ts_added_secs int64 `json:"ts_added_secs"` } -func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{ - db: db, - } - _, err := db.Exec(staleDeviceListsSchema) +type StaleDeviceListCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"` +} + +// const upsertStaleDeviceListSQL = "" + +// "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + +// " VALUES ($1, $2, $3, $4)" + +// " ON CONFLICT (user_id)" + +// " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" +const selectStaleDeviceListsWithDomainsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_stale_device_list.is_stale = @x2 " + + "and c.mx_keyserver_stale_device_list.domain = @x3 " + +// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" +const selectStaleDeviceListsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_stale_device_list.is_stale = @x2 " + +type staleDeviceListsStatements struct { + db *Database + // upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt string + selectStaleDeviceListsStmt string + tableName string +} + +func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) { + var response []StaleDeviceListCosmosData + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err + + return response, nil +} + +func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, } + s.selectStaleDeviceListsStmt = selectStaleDeviceListsSQL + s.selectStaleDeviceListsWithDomainsStmt = selectStaleDeviceListsWithDomainsSQL + s.tableName = "stale_device_lists" return s, nil } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + + // "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + // " VALUES ($1, $2, $3, $4)" + + // " ON CONFLICT (user_id)" + + // " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // user_id TEXT PRIMARY KEY NOT NULL, + docId := userID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + data := StaleDeviceListCosmos{ + Domain: string(domain), + IsStale: isStale, + UserID: userID, + } + + dbData := StaleDeviceListCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + StaleDeviceList: data, + } + + // _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + return err } func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) if len(domains) == 0 { - rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + + // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + // rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": true, + } + rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params) + if err != nil { return nil, err } @@ -95,7 +171,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte } var result []string for _, domain := range domains { - rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + + // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + // rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": true, + "@x3": string(domain), + } + + rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params) + if err != nil { return nil, err } @@ -108,14 +194,11 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } -func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { - defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") - for rows.Next() { +func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) { + for _, item := range rows { var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } + userID = item.StaleDeviceList.UserID result = append(result, userID) } - return result, rows.Err() + return result, nil } diff --git a/keyserver/storage/cosmosdb/storage.go b/keyserver/storage/cosmosdb/storage.go index ba000cb24..004d8d7d6 100644 --- a/keyserver/storage/cosmosdb/storage.go +++ b/keyserver/storage/cosmosdb/storage.go @@ -15,35 +15,53 @@ package cosmosdb import ( - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) +// A Database is used to store room events and stream offsets. +type Database struct { + shared.Database + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig + serverName gomatrixserverlib.ServerName +} + func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, err := sqlutil.Open(dbProperties) + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + d := &Database{ + databaseName: "keyserver", + connection: conn, + cosmosConfig: config, + } + + // db, err := sqlutil.Open(dbProperties) + // if err != nil { + // return nil, err + // } + otk, err := NewCosmosDBOneTimeKeysTable(d) if err != nil { return nil, err } - otk, err := NewSqliteOneTimeKeysTable(db) + dk, err := NewCosmosDBDeviceKeysTable(d) if err != nil { return nil, err } - dk, err := NewSqliteDeviceKeysTable(db) + kc, err := NewCosmosDBKeyChangesTable(d) if err != nil { return nil, err } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) + sdl, err := NewCosmosDBStaleDeviceListsTable(d) if err != nil { return nil, err } return &shared.Database{ - DB: db, - Writer: sqlutil.NewExclusiveWriter(), + Writer: cosmosdbutil.NewExclusiveWriterFake(), OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc,