From 60e11f88b84ea685d962e121e165ad197783bb5c Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Fri, 8 Oct 2021 11:17:22 +1100 Subject: [PATCH] - Remove PerformQueryAllPartitions as it does not support aggreates (#24) - Update queries to all use PartitionKeys - Remove the _sid from queries as the PK contains the Tenant - Fix some bugs around empty values and ordering Co-authored-by: alexf@example.com --- internal/cosmosdbapi/client.go | 46 ------------------- internal/cosmosdbapi/documentoperations.go | 8 ---- .../storage/cosmosdb/device_keys_table.go | 22 ++++----- .../storage/cosmosdb/key_changes_table.go | 20 ++++---- .../storage/cosmosdb/stale_device_lists.go | 11 ++--- roomserver/storage/cosmosdb/events_table.go | 22 +++++---- .../storage/cosmosdb/account_data_table.go | 7 ++- syncapi/storage/cosmosdb/filtering.go | 4 +- syncapi/storage/cosmosdb/invites_table.go | 7 ++- .../cosmosdb/output_room_events_table.go | 15 ++---- syncapi/storage/cosmosdb/peeks_table.go | 3 +- syncapi/storage/cosmosdb/receipt_table.go | 7 ++- .../storage/cosmosdb/send_to_device_table.go | 7 ++- .../cosmosdb/key_backup_version_table.go | 17 ++++--- 14 files changed, 71 insertions(+), 125 deletions(-) diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go index e6a271d3f..513b6bd48 100644 --- a/internal/cosmosdbapi/client.go +++ b/internal/cosmosdbapi/client.go @@ -84,35 +84,6 @@ func PerformQuery(ctx context.Context, return err } -func PerformQueryAllPartitions(ctx context.Context, - conn CosmosConnection, - databaseName string, - containerName string, - qryString string, - params map[string]interface{}, - response interface{}) error { - err := validateQueryAllPartitions(qryString) - if err != nil { - return err - } - var optionsQry = getQueryAllPartitionsDocumentsOptions() - var query = getQuery(qryString, params) - _, err = GetClient(conn).QueryDocuments( - ctx, - databaseName, - containerName, - query, - &response, - optionsQry) - - // When there are no Rows we seem to get the generic Bad Req JSON error - if err != nil { - // return nil, err - } - - return nil -} - func GenerateDocument( collection string, tenantName string, @@ -181,20 +152,3 @@ func validateQuery(qryString string) error { } return nil } - -func validateQueryAllPartitions(qryString string) error { - err := validateQuery(qryString) - if err != nil { - return err - } - if strings.Contains(qryString, " top ") { - return errors.New("qryString contains [ top ] ") - } - if strings.Contains(qryString, " order by ") { - return errors.New("qryString contains [ order by ] ") - } - if !strings.Contains(qryString, " c._sid = ") { - return errors.New("qryString must contain [ c._sid = ] ") - } - return nil -} diff --git a/internal/cosmosdbapi/documentoperations.go b/internal/cosmosdbapi/documentoperations.go index 2285b3511..cc1be7bc7 100644 --- a/internal/cosmosdbapi/documentoperations.go +++ b/internal/cosmosdbapi/documentoperations.go @@ -26,14 +26,6 @@ func getQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { } } -func getQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions { - return cosmosapi.QueryDocumentsOptions{ - IsQuery: true, - EnableCrossPartition: true, - ContentType: cosmosapi.QUERY_CONTENT_TYPE, - } -} - func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions { return cosmosapi.GetDocumentOptions{ PartitionKeyValue: pk, diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go index 67bd6424d..ebec4f259 100644 --- a/keyserver/storage/cosmosdb/device_keys_table.go +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -75,14 +75,14 @@ const selectBatchDeviceKeysSQL = "" + // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + - "select max(c.mx_keyserver_device_key.stream_id) as number from c where c._sid = @x1 and c._cn = @x2 " + - "and c.mx_keyserver_device_key.user_id = @x3 " + "select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " // "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" const countStreamIDsForUserSQL = "" + - "select count(c._ts) as number from c where c._sid = @x1 and c._cn = @x2 " + - "and c.mx_keyserver_device_key.user_id = @x3 " + - "and ARRAY_CONTAINS(@x4, c.mx_keyserver_device_key.stream_id) " + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + + "and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) " const selectAllDeviceKeysSQL = "" + "select * from c where c._cn = @x1 " + @@ -356,9 +356,8 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": userID, + "@x1": s.getCollectionName(), + "@x2": userID, } // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) @@ -398,10 +397,9 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID } params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": userID, - "@x4": iStreamIDs, + "@x1": s.getCollectionName(), + "@x2": userID, + "@x3": iStreamIDs, } // query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) diff --git a/keyserver/storage/cosmosdb/key_changes_table.go b/keyserver/storage/cosmosdb/key_changes_table.go index 86ced845c..d70f75438 100644 --- a/keyserver/storage/cosmosdb/key_changes_table.go +++ b/keyserver/storage/cosmosdb/key_changes_table.go @@ -65,10 +65,10 @@ type keyChangeCosmosData struct { // "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" const selectKeyChangesSQL = "" + "select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " + - "from c where c._sid = @x1 and c._cn = @x2 " + - "and c.mx_keyserver_key_change.partition = @x3 " + - "and c.mx_keyserver_key_change._offset > @x4 " + - "and c.mx_keyserver_key_change._offset < @x5 " + + "from c where c._cn = @x1 " + + "and c.mx_keyserver_key_change.partition = @x2 " + + "and c.mx_keyserver_key_change._offset > @x3 " + + "and c.mx_keyserver_key_change._offset < @x4 " + "group by c.mx_keyserver_key_change.user_id " type keyChangesStatements struct { @@ -161,18 +161,18 @@ func (s *keyChangesStatements) SelectKeyChanges( // rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": partition, - "@x4": fromOffset, - "@x5": toOffset, + "@x1": s.getCollectionName(), + "@x2": partition, + "@x3": fromOffset, + "@x4": toOffset, } var rows []keyChangeUserMaxCosmosData - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeyChangesStmt, params, &rows) if err != nil { diff --git a/keyserver/storage/cosmosdb/stale_device_lists.go b/keyserver/storage/cosmosdb/stale_device_lists.go index 60cfa46c3..bfe7bae8c 100644 --- a/keyserver/storage/cosmosdb/stale_device_lists.go +++ b/keyserver/storage/cosmosdb/stale_device_lists.go @@ -55,9 +55,9 @@ type staleDeviceListCosmosData struct { // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" const selectStaleDeviceListsWithDomainsSQL = "" + - "select * from c where c._sid = @x1 and c._cn = @x2 " + - "and c.mx_keyserver_stale_device_list.is_stale = @x3 " + - "and c.mx_keyserver_stale_device_list.domain = @x4 " + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_stale_device_list.is_stale = @x2 " + + "and c.mx_keyserver_stale_device_list.domain = @x3 " // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" const selectStaleDeviceListsSQL = "" + @@ -156,9 +156,8 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" // rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": true, + "@x1": s.getCollectionName(), + "@x2": true, } var rows []staleDeviceListCosmosData diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index ec63120ac..91a95788d 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -146,9 +146,8 @@ const bulkSelectEventNIDSQL = "" + // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" const selectMaxEventDepthSQL = "" + - "select sub.maxinner != null ? sub.maxinner + 1 : 0 as maxdepth from " + - "(select MAX(c.mx_roomserver_event.depth) maxinner from c where c._sid = @x1 and c._cn = @x2 " + - " and ARRAY_CONTAINS(@x3, c.mx_roomserver_event.event_nid)) sub" + "select MAX(c.mx_roomserver_event.depth) maxdepth from c where c._cn = @x1 " + + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)" // "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" const selectRoomNIDsForEventNIDsSQL = "" + @@ -877,22 +876,29 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": eventNIDs, + "@x1": s.getCollectionName(), + "@x2": eventNIDs, } var rows []eventCosmosMaxDepth - err := cosmosdbapi.PerformQueryAllPartitions(ctx, + err := cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectMaxEventDepthSQL, params, &rows) if err != nil { return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } - return rows[0].Max, nil + if len(rows) == 0 { + return 0, nil + } + result := rows[0].Max + if result == 0 { + return 0, nil + } + return result + 1, nil } func (s *eventStatements) SelectRoomNIDsForEventNIDs( diff --git a/syncapi/storage/cosmosdb/account_data_table.go b/syncapi/storage/cosmosdb/account_data_table.go index e6e74e09e..36ee845e4 100644 --- a/syncapi/storage/cosmosdb/account_data_table.go +++ b/syncapi/storage/cosmosdb/account_data_table.go @@ -72,8 +72,7 @@ const selectAccountDataInRangeSQL = "" + // "SELECT MAX(id) FROM syncapi_account_data_type" const selectMaxAccountDataIDSQL = "" + - "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " + - "and c._sid = @x2 " + "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " type accountDataStatements struct { db *SyncServerDatasource @@ -248,14 +247,14 @@ func (s *accountDataStatements) SelectMaxAccountDataID( // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) params := map[string]interface{}{ "@x1": s.getCollectionName(), - "@x2": s.db.cosmosConfig.TenantName, } var rows []AccountDataTypeNumberCosmosData - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxAccountDataIDStmt, params, &rows) if err != cosmosdbutil.ErrNoRows && len(rows) == 1 { diff --git a/syncapi/storage/cosmosdb/filtering.go b/syncapi/storage/cosmosdb/filtering.go index 6d5acda4b..0ced91349 100644 --- a/syncapi/storage/cosmosdb/filtering.go +++ b/syncapi/storage/cosmosdb/filtering.go @@ -69,9 +69,9 @@ func prepareWithFilters( } switch order { case FilterOrderAsc: - sql += fmt.Sprintf("order by c.%s.event_id asc ", collectionName) + sql += fmt.Sprintf("order by c.%s.id asc ", collectionName) case FilterOrderDesc: - sql += fmt.Sprintf("order by c.%s.event_id desc ", collectionName) + sql += fmt.Sprintf("order by c.%s.id desc ", collectionName) } // query += fmt.Sprintf(" LIMIT $%d", offset+1) return diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go index d05ed268e..a25efb8a1 100644 --- a/syncapi/storage/cosmosdb/invites_table.go +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -83,8 +83,7 @@ const selectInviteEventsInRangeSQL = "" + // "SELECT MAX(id) FROM syncapi_invite_events" const selectMaxInviteIDSQL = "" + - "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " + - "and c._sid = @x2 " + "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " type inviteEventsStatements struct { db *SyncServerDatasource @@ -296,14 +295,14 @@ func (s *inviteEventsStatements) SelectMaxInviteID( // err = stmt.QueryRowContext(ctx).Scan(&nullableID) params := map[string]interface{}{ "@x1": s.getCollectionName(), - "@x2": s.db.cosmosConfig.TenantName, } var rows []inviteEventCosmosMaxNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxInviteIDStmt, params, &rows) if len(rows) > 0 { diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go index 4b2cbced0..36edaf78d 100644 --- a/syncapi/storage/cosmosdb/output_room_events_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -115,8 +115,7 @@ const selectEarlyEventsSQL = "" + // "SELECT MAX(id) FROM syncapi_output_room_events" const selectMaxEventIDSQL = "" + - "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " + - "and c._sid = @x2 " + "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" const updateEventJSONSQL = "" + @@ -336,29 +335,25 @@ func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { - var nullableID sql.NullInt64 params := map[string]interface{}{ "@x1": s.getCollectionName(), - "@x2": s.db.cosmosConfig.TenantName, } // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) var rows []outputRoomEventCosmosMaxNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxEventIDStmt, params, &rows) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) - if rows != nil { - nullableID.Int64 = rows[0].Max + if len(rows) > 0 { + id = rows[0].Max } - if nullableID.Valid { - id = nullableID.Int64 - } return } diff --git a/syncapi/storage/cosmosdb/peeks_table.go b/syncapi/storage/cosmosdb/peeks_table.go index 6795e08d3..5db726889 100644 --- a/syncapi/storage/cosmosdb/peeks_table.go +++ b/syncapi/storage/cosmosdb/peeks_table.go @@ -382,10 +382,11 @@ func (s *peekStatements) SelectMaxPeekID( "@x1": s.getCollectionName(), } var rows []peekCosmosMaxNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxPeekIDStmt, params, &rows) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) diff --git a/syncapi/storage/cosmosdb/receipt_table.go b/syncapi/storage/cosmosdb/receipt_table.go index 90affd3c1..4ce868ddc 100644 --- a/syncapi/storage/cosmosdb/receipt_table.go +++ b/syncapi/storage/cosmosdb/receipt_table.go @@ -76,8 +76,7 @@ const selectRoomReceipts = "" + // "SELECT MAX(id) FROM syncapi_receipts" const selectMaxReceiptIDSQL = "" + - "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " + - "and c._sid = @x2 " + "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " type receiptStatements struct { db *SyncServerDatasource @@ -210,13 +209,13 @@ func (s *receiptStatements) SelectMaxReceiptID( params := map[string]interface{}{ "@x1": s.getCollectionName(), - "@x2": s.db.cosmosConfig.TenantName, } var rows []receiptCosmosMaxNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxReceiptID, params, &rows) // stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go index 9dd102531..ec0793edc 100644 --- a/syncapi/storage/cosmosdb/send_to_device_table.go +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -83,8 +83,7 @@ const deleteSendToDeviceMessagesSQL = "" + // "SELECT MAX(id) FROM syncapi_send_to_device" const selectMaxSendToDeviceIDSQL = "" + - "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " + - "and c._sid = @x2 " + "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " type sendToDeviceStatements struct { db *SyncServerDatasource @@ -274,13 +273,13 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID( params := map[string]interface{}{ "@x1": s.getCollectionName(), - "@x2": s.db.cosmosConfig.TenantName, } var rows []SendToDeviceCosmosMaxNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxSendToDeviceIDStmt, params, &rows) // stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go index 9ed269c8c..5f986c8f8 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go @@ -76,8 +76,8 @@ type keyBackupVersionCosmosNumber struct { // "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" const selectLatestVersionSQL = "" + - "select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._sid = @x1 and c._cn = @x2 " + - "and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x3 " + "select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._cn = @x1 " + + "and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x2 " type keyBackupVersionStatements struct { db *Database @@ -276,17 +276,17 @@ func (s *keyBackupVersionStatements) selectKeyBackup( if version == "" { // var v *int64 // allows nulls params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": s.getCollectionName(), - "@x3": userID, + "@x1": s.getCollectionName(), + "@x2": userID, } // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) var rows []keyBackupVersionCosmosNumber - err = cosmosdbapi.PerformQueryAllPartitions(ctx, + err = cosmosdbapi.PerformQuery(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, + s.getPartitionKey(userID), s.selectLatestVersionStmt, params, &rows) if err != nil { @@ -303,6 +303,11 @@ func (s *keyBackupVersionStatements) selectKeyBackup( return } versionInt = rows[0].Number + if versionInt == 0 { + err = cosmosdbutil.ErrNoRows + return + } + } else { if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { return