From 3088238419833242d4d0a0fda9ea368e3284a835 Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Thu, 23 Sep 2021 14:48:32 +1000 Subject: [PATCH] Add UniqueId to PartitionKey for some Dendrite tables (where possible) (#19) * - Make all PartitionKeys include the tablename - Update specific PKs to be item specific - Add validation to the PerformQueryXX methods - Fix queries that fail validation * - Revert the PK back to CollectionName as it already includes the TableName Co-authored-by: alexf@example.com --- .../storage/cosmosdb/inbound_peeks_table.go | 21 +++++----- .../storage/cosmosdb/outbound_peeks_table.go | 19 +++++---- internal/cosmosdbapi/client.go | 41 ++++++++++++++++++- .../cosmosdbutil/partition_offset_table.go | 11 ++--- .../naffkacosmosdb/naffka_topics_table.go | 11 ++--- .../cosmosdb/cross_signing_keys_table.go | 11 ++--- .../cosmosdb/cross_signing_sigs_table.go | 13 +++--- .../storage/cosmosdb/device_keys_table.go | 19 +++++---- .../storage/cosmosdb/one_time_keys_table.go | 15 +++---- .../cosmosdb/media_repository_table.go | 1 + mediaapi/storage/cosmosdb/thumbnail_table.go | 11 ++--- .../storage/cosmosdb/event_json_table.go | 1 + roomserver/storage/cosmosdb/events_table.go | 1 + roomserver/storage/cosmosdb/invite_table.go | 12 +++--- .../storage/cosmosdb/previous_events_table.go | 11 ++--- .../storage/cosmosdb/redactions_table.go | 1 - .../storage/cosmosdb/transactions_table.go | 9 ++-- .../storage/cosmosdb/account_data_table.go | 5 ++- .../cosmosdb/backwards_extremities_table.go | 15 +++---- syncapi/storage/cosmosdb/invites_table.go | 5 ++- .../cosmosdb/output_room_events_table.go | 5 ++- .../output_room_events_topology_table.go | 1 + syncapi/storage/cosmosdb/receipt_table.go | 4 +- .../storage/cosmosdb/send_to_device_table.go | 4 +- .../accounts/cosmosdb/account_data_table.go | 13 +++--- .../accounts/cosmosdb/key_backup_table.go | 17 ++++---- .../cosmosdb/key_backup_version_table.go | 15 +++---- .../accounts/cosmosdb/profile_table.go | 4 +- .../storage/devices/cosmosdb/devices_table.go | 1 + userapi/storage/devices/cosmosdb/storage.go | 3 -- 30 files changed, 184 insertions(+), 116 deletions(-) diff --git a/federationsender/storage/cosmosdb/inbound_peeks_table.go b/federationsender/storage/cosmosdb/inbound_peeks_table.go index 48ad7af23..9dd5256dd 100644 --- a/federationsender/storage/cosmosdb/inbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/inbound_peeks_table.go @@ -92,8 +92,9 @@ func (s *inboundPeeksStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *inboundPeeksStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *inboundPeeksStatements) getPartitionKey(roomId string) string { + uniqueId := roomId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) { @@ -163,7 +164,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.InboundPeek.RenewedTimestamp = nowMilli @@ -179,7 +180,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( } dbData = &inboundPeekCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId), InboundPeek: data, } } @@ -208,7 +209,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if err != nil { return @@ -233,10 +234,10 @@ func (s *inboundPeeksStatements) SelectInboundPeek( // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(), docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), docId) // row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + row, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if row == nil { return nil, nil @@ -270,7 +271,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectInboundPeeksStmt, params, &rows) + s.getPartitionKey(roomID), s.selectInboundPeeksStmt, params, &rows) if err != nil { return @@ -307,7 +308,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeek( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows) if err != nil { return @@ -339,7 +340,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeeks( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows) if err != nil { return diff --git a/federationsender/storage/cosmosdb/outbound_peeks_table.go b/federationsender/storage/cosmosdb/outbound_peeks_table.go index 828ca2222..cc707886f 100644 --- a/federationsender/storage/cosmosdb/outbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/outbound_peeks_table.go @@ -89,8 +89,9 @@ func (s *outboundPeeksStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *outboundPeeksStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *outboundPeeksStatements) getPartitionKey(roomId string) string { + uniqueId := roomId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) { @@ -159,7 +160,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.OutboundPeek.RenewalInterval = renewalInterval @@ -176,7 +177,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( } dbData = &outboundPeekCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId), OutboundPeek: data, } @@ -205,7 +206,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if err != nil { return @@ -233,7 +234,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeek( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) + row, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if err != nil { return nil, err @@ -273,7 +274,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectOutboundPeeksStmt, params, &rows) + s.getPartitionKey(roomID), s.selectOutboundPeeksStmt, params, &rows) if err != nil { return @@ -311,7 +312,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeek( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteOutboundPeekStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteOutboundPeekStmt, params, &rows) if err != nil { return @@ -344,7 +345,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeeks( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteOutboundPeeksStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteOutboundPeeksStmt, params, &rows) if err != nil { return diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go index f7e74001c..29837f00f 100644 --- a/internal/cosmosdbapi/client.go +++ b/internal/cosmosdbapi/client.go @@ -2,6 +2,8 @@ package cosmosdbapi import ( "context" + "errors" + "strings" "time" cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" @@ -55,9 +57,13 @@ func PerformQuery(ctx context.Context, qryString string, params map[string]interface{}, response interface{}) error { + err := validateQuery(qryString) + if err != nil { + return err + } optionsQry := GetQueryDocumentsOptions(partitonKey) var query = GetQuery(qryString, params) - _, err := GetClient(conn).QueryDocuments( + _, err = GetClient(conn).QueryDocuments( ctx, databaseName, containerName, @@ -74,9 +80,13 @@ func PerformQueryAllPartitions(ctx context.Context, qryString string, params map[string]interface{}, response interface{}) error { + err := validateQueryAllPartitions(qryString) + if err != nil { + return err + } var optionsQry = GetQueryAllPartitionsDocumentsOptions() var query = GetQuery(qryString, params) - _, err := GetClient(conn).QueryDocuments( + _, err = GetClient(conn).QueryDocuments( ctx, databaseName, containerName, @@ -130,3 +140,30 @@ func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx cont return nil } + +func validateQuery(qryString string) error { + if len(qryString) == 0 { + return errors.New("qryString was nil") + } + if !strings.Contains(qryString, " c._cn = ") { + return errors.New("qryString must contain [ c._cn = ] ") + } + return nil +} + +func validateQueryAllPartitions(qryString string) error { + err := validateQuery(qryString) + if err != nil { + return err + } + if strings.Contains(qryString, " top ") { + return errors.New("qryString contains [ top ] ") + } + if strings.Contains(qryString, " order by ") { + return errors.New("qryString contains [ order by ] ") + } + if !strings.Contains(qryString, " c._sid = ") { + return errors.New("qryString must contain [ c._sid = ] ") + } + return nil +} diff --git a/internal/cosmosdbutil/partition_offset_table.go b/internal/cosmosdbutil/partition_offset_table.go index c3c2fe991..cb206ae60 100644 --- a/internal/cosmosdbutil/partition_offset_table.go +++ b/internal/cosmosdbutil/partition_offset_table.go @@ -90,8 +90,9 @@ func (s PartitionOffsetStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName) } -func (s *PartitionOffsetStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.CosmosConfig.TenantName, s.getCollectionName()) +func (s *PartitionOffsetStatements) getPartitionKey(topic string) string { + uniqueId := topic + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.CosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) { @@ -154,7 +155,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( s.db.Connection, s.db.CosmosConfig.DatabaseName, s.db.CosmosConfig.ContainerName, - s.getPartitionKey(), s.selectPartitionOffsetsStmt, params, &rows) + s.getPartitionKey(topic), s.selectPartitionOffsetsStmt, params, &rows) // rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) if err != nil { @@ -195,7 +196,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset( docId := fmt.Sprintf("%s_%d", topic, partition) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(topic), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.PartitionOffset.PartitionOffset = offset @@ -207,7 +208,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset( } dbData = &partitionOffsetCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(topic), cosmosDocId), PartitionOffset: data, } diff --git a/internal/naffka/naffkacosmosdb/naffka_topics_table.go b/internal/naffka/naffkacosmosdb/naffka_topics_table.go index 67d912364..2e123e6ea 100644 --- a/internal/naffka/naffkacosmosdb/naffka_topics_table.go +++ b/internal/naffka/naffkacosmosdb/naffka_topics_table.go @@ -116,8 +116,9 @@ func (s *topicsStatements) getCollectionNameMessages() string { return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages) } -func (s *topicsStatements) getPartitionKeyMessages() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages()) +func (s *topicsStatements) getPartitionKeyMessages(topicNid int64) string { + uniqueId := fmt.Sprintf("%d", topicNid) + return cosmosdbapi.GetPartitionKeyByUniqueId(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages(), uniqueId) } func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) { @@ -310,7 +311,7 @@ func (t *topicsStatements) InsertTopics( } dbData := &messageCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(topicNID), cosmosDocId), Message: data, } @@ -348,7 +349,7 @@ func (t *topicsStatements) SelectMessages( t.DB.connection, t.DB.cosmosConfig.DatabaseName, t.DB.cosmosConfig.ContainerName, - t.getPartitionKeyMessages(), t.selectMessagesStmt, params, &rows) + t.getPartitionKeyMessages(topicNID), t.selectMessagesStmt, params, &rows) if err != nil { return nil, err @@ -387,7 +388,7 @@ func (t *topicsStatements) SelectMaxOffset( t.DB.connection, t.DB.cosmosConfig.DatabaseName, t.DB.cosmosConfig.ContainerName, - t.getPartitionKeyMessages(), t.selectMaxOffsetStmt, params, &rows) + t.getPartitionKeyMessages(topicNID), t.selectMaxOffsetStmt, params, &rows) if err != nil { return 0, err diff --git a/keyserver/storage/cosmosdb/cross_signing_keys_table.go b/keyserver/storage/cosmosdb/cross_signing_keys_table.go index 735d062ff..f654bde64 100644 --- a/keyserver/storage/cosmosdb/cross_signing_keys_table.go +++ b/keyserver/storage/cosmosdb/cross_signing_keys_table.go @@ -66,8 +66,9 @@ func (s *crossSigningKeysStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *crossSigningKeysStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *crossSigningKeysStatements) getPartitionKey(userId string) string { + uniqueId := userId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) { @@ -112,7 +113,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectCrossSigningKeysForUserStmt, params, &rows) + s.getPartitionKey(userID), s.selectCrossSigningKeysForUserStmt, params, &rows) // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) if err != nil { @@ -151,7 +152,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( docId := fmt.Sprintf("%s_%s", userID, keyType) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(userID), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.CrossSigningKeys.KeyData = keyData @@ -163,7 +164,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( } dbData = &crossSigningKeysCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId), CrossSigningKeys: data, } } diff --git a/keyserver/storage/cosmosdb/cross_signing_sigs_table.go b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go index 5e1cc0e3e..30037d830 100644 --- a/keyserver/storage/cosmosdb/cross_signing_sigs_table.go +++ b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go @@ -78,8 +78,9 @@ func (s *crossSigningSigsStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *crossSigningSigsStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *crossSigningSigsStatements) getPartitionKey(targetUserId string) string { + uniqueId := targetUserId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) { @@ -145,7 +146,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) + s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows) // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) if err != nil { @@ -185,7 +186,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(targetUserID), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.CrossSigningSigs.OriginKeyId = string(originKeyID) @@ -200,7 +201,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( } dbData = &crossSigningSigsCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(targetUserID), cosmosDocId), CrossSigningSigs: data, } } @@ -230,7 +231,7 @@ func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) + s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows) // if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { // return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go index cf17c1c6e..d18023262 100644 --- a/keyserver/storage/cosmosdb/device_keys_table.go +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -168,8 +168,9 @@ func (s *deviceKeysStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *deviceKeysStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *deviceKeysStatements) getPartitionKey(userId string) string { + uniqueId := userId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { @@ -212,7 +213,7 @@ func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) + s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows) if err != nil { return err @@ -242,7 +243,7 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) + s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows) if err != nil { return err @@ -275,7 +276,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectBatchDeviceKeysStmt, params, &rows) + s.getPartitionKey(userID), s.selectBatchDeviceKeysStmt, params, &rows) // rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { @@ -327,7 +328,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getDeviceKey(s, ctx, s.getPartitionKey(), cosmosDocId) + response, err := getDeviceKey(s, ctx, s.getPartitionKey(key.UserID), cosmosDocId) if err != nil && err != cosmosdbutil.ErrNoRows { return err @@ -366,7 +367,7 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), selectMaxStreamForUserSQL, params, &rows) + s.getPartitionKey(userID), selectMaxStreamForUserSQL, params, &rows) if err != nil { if err == cosmosdbutil.ErrNoRows { @@ -413,7 +414,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), countStreamIDsForUserSQL, params, &rows) + s.getPartitionKey(userID), countStreamIDsForUserSQL, params, &rows) if err != nil { return 0, err @@ -440,7 +441,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) dbData := &deviceKeyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(key.UserID), cosmosDocId), DeviceKey: mapFromDeviceKeyMessage(key), } diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go index 0f2a52e5f..a3972e070 100644 --- a/keyserver/storage/cosmosdb/one_time_keys_table.go +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -106,8 +106,9 @@ func (s *oneTimeKeysStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *oneTimeKeysStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *oneTimeKeysStatements) getPartitionKey(userId string) string { + uniqueId := userId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) { @@ -194,7 +195,7 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) + s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows) if err != nil { return nil, err @@ -239,7 +240,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) + s.getPartitionKey(counts.UserID), s.selectKeysCountStmt, params, &rows) if err != nil { return nil, err @@ -286,7 +287,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( } dbData := &oneTimeKeyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(counts.UserID), cosmosDocId), OneTimeKey: data, } @@ -309,7 +310,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) + s.getPartitionKey(keys.UserID), s.selectKeysCountStmt, params, &rows) if err != nil { return nil, err @@ -346,7 +347,7 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) + s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows) if err != nil { if err == cosmosdbutil.ErrNoRows { diff --git a/mediaapi/storage/cosmosdb/media_repository_table.go b/mediaapi/storage/cosmosdb/media_repository_table.go index ff0d010bf..63622269b 100644 --- a/mediaapi/storage/cosmosdb/media_repository_table.go +++ b/mediaapi/storage/cosmosdb/media_repository_table.go @@ -99,6 +99,7 @@ func (s *mediaStatements) getCollectionName() string { } func (s *mediaStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } diff --git a/mediaapi/storage/cosmosdb/thumbnail_table.go b/mediaapi/storage/cosmosdb/thumbnail_table.go index 058d1be70..a2cea0f6b 100644 --- a/mediaapi/storage/cosmosdb/thumbnail_table.go +++ b/mediaapi/storage/cosmosdb/thumbnail_table.go @@ -89,8 +89,9 @@ func (s *thumbnailStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *thumbnailStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *thumbnailStatements) getPartitionKey(mediaId string) string { + uniqueId := mediaId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) { @@ -163,7 +164,7 @@ func (s *thumbnailStatements) insertThumbnail( } dbData := &thumbnailCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(data.MediaID), cosmosDocId), Thumbnail: data, } @@ -209,7 +210,7 @@ func (s *thumbnailStatements) selectThumbnail( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getThumbnail(s, ctx, s.getPartitionKey(), cosmosDocId) + row, err := getThumbnail(s, ctx, s.getPartitionKey(string(mediaID)), cosmosDocId) if err != nil { return nil, err @@ -250,7 +251,7 @@ func (s *thumbnailStatements) selectThumbnails( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectThumbnailsStmt, params, &rows) + s.getPartitionKey(string(mediaID)), s.selectThumbnailsStmt, params, &rows) if err != nil { return nil, err diff --git a/roomserver/storage/cosmosdb/event_json_table.go b/roomserver/storage/cosmosdb/event_json_table.go index 48ff4fed7..7b97022ce 100644 --- a/roomserver/storage/cosmosdb/event_json_table.go +++ b/roomserver/storage/cosmosdb/event_json_table.go @@ -70,6 +70,7 @@ func (s *eventJSONStatements) getCollectionName() string { } func (s *eventJSONStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index 4c199b720..470483656 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -178,6 +178,7 @@ func (s *eventStatements) getCollectionName() string { } func (s *eventStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } diff --git a/roomserver/storage/cosmosdb/invite_table.go b/roomserver/storage/cosmosdb/invite_table.go index 68f5b6c60..c5083fabd 100644 --- a/roomserver/storage/cosmosdb/invite_table.go +++ b/roomserver/storage/cosmosdb/invite_table.go @@ -18,6 +18,7 @@ package cosmosdb import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbutil" @@ -97,8 +98,9 @@ func (s *inviteStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *inviteStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *inviteStatements) getPartitionKey(roomNId int64) string { + uniqueId := fmt.Sprintf("%d", roomNId) + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) { @@ -169,7 +171,7 @@ func (s *inviteStatements) InsertInviteEvent( } var dbData = inviteCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(int64(roomNID)), cosmosDocId), Invite: data, } @@ -211,7 +213,7 @@ func (s *inviteStatements) UpdateInviteRetired( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectInvitesAboutToRetireStmt, params, &rows) + s.getPartitionKey(int64(roomNID)), s.selectInvitesAboutToRetireStmt, params, &rows) if err != nil { return @@ -248,7 +250,7 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectInviteActiveForUserInRoomStmt, params, &rows) + s.getPartitionKey(int64(roomNID)), s.selectInviteActiveForUserInRoomStmt, params, &rows) if err != nil { return nil, nil, err diff --git a/roomserver/storage/cosmosdb/previous_events_table.go b/roomserver/storage/cosmosdb/previous_events_table.go index 448ade8a1..75c70b56a 100644 --- a/roomserver/storage/cosmosdb/previous_events_table.go +++ b/roomserver/storage/cosmosdb/previous_events_table.go @@ -89,8 +89,9 @@ func (s *previousEventStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *previousEventStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *previousEventStatements) getPartitionKey(previousEventId string) string { + uniqueId := previousEventId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) { @@ -141,7 +142,7 @@ func (s *previousEventStatements) InsertPreviousEvent( // SELECT 1 FROM roomserver_previous_events // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 - existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) + existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(previousEventID), cosmosDocId) if err != nil { if err != cosmosdbutil.ErrNoRows { @@ -159,7 +160,7 @@ func (s *previousEventStatements) InsertPreviousEvent( } dbData = previousEventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(previousEventID), cosmosDocId), PreviousEvent: data, } } else { @@ -206,7 +207,7 @@ func (s *previousEventStatements) SelectPreviousEventExists( // SELECT 1 FROM roomserver_previous_events // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 - dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(eventID), cosmosDocId) if err != nil { return err } diff --git a/roomserver/storage/cosmosdb/redactions_table.go b/roomserver/storage/cosmosdb/redactions_table.go index c3dcad635..0f71819d4 100644 --- a/roomserver/storage/cosmosdb/redactions_table.go +++ b/roomserver/storage/cosmosdb/redactions_table.go @@ -180,7 +180,6 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { info = nil - err = err return } diff --git a/roomserver/storage/cosmosdb/transactions_table.go b/roomserver/storage/cosmosdb/transactions_table.go index 016254512..c245ffa20 100644 --- a/roomserver/storage/cosmosdb/transactions_table.go +++ b/roomserver/storage/cosmosdb/transactions_table.go @@ -68,8 +68,9 @@ func (s *transactionStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *transactionStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *transactionStatements) getPartitionKey(transactionID string) string { + uniqueId := transactionID + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) { @@ -124,7 +125,7 @@ func (s *transactionStatements) InsertTransaction( } var dbData = transactionCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(transactionID), cosmosDocId), Transaction: data, } @@ -153,7 +154,7 @@ func (s *transactionStatements) SelectTransactionEventID( docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getTransaction(s, ctx, s.getPartitionKey(), cosmosDocId) + response, err := getTransaction(s, ctx, s.getPartitionKey(transactionID), cosmosDocId) if err != nil { return "", err diff --git a/syncapi/storage/cosmosdb/account_data_table.go b/syncapi/storage/cosmosdb/account_data_table.go index 670b4f309..bad466016 100644 --- a/syncapi/storage/cosmosdb/account_data_table.go +++ b/syncapi/storage/cosmosdb/account_data_table.go @@ -72,7 +72,8 @@ const selectAccountDataInRangeSQL = "" + // "SELECT MAX(id) FROM syncapi_account_data_type" const selectMaxAccountDataIDSQL = "" + - "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " + + "and c._sid = @x2 " type accountDataStatements struct { db *SyncServerDatasource @@ -88,6 +89,7 @@ func (s *accountDataStatements) getCollectionName() string { } func (s *accountDataStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } @@ -245,6 +247,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID( // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) params := map[string]interface{}{ "@x1": s.getCollectionName(), + "@x2": s.db.cosmosConfig.TenantName, } var rows []AccountDataTypeNumberCosmosData diff --git a/syncapi/storage/cosmosdb/backwards_extremities_table.go b/syncapi/storage/cosmosdb/backwards_extremities_table.go index 113e7e573..d19c20427 100644 --- a/syncapi/storage/cosmosdb/backwards_extremities_table.go +++ b/syncapi/storage/cosmosdb/backwards_extremities_table.go @@ -82,8 +82,9 @@ func (s *backwardExtremitiesStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *backwardExtremitiesStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *backwardExtremitiesStatements) getPartitionKey(roomId string) string { + uniqueId := roomId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) { @@ -143,7 +144,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() } else { @@ -154,7 +155,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } dbData = &backwardExtremityCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId), BackwardExtremity: data, } } @@ -185,7 +186,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectBackwardExtremitiesForRoomStmt, params, &rows) + s.getPartitionKey(roomID), s.selectBackwardExtremitiesForRoomStmt, params, &rows) if err != nil { return @@ -221,7 +222,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteBackwardExtremityStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteBackwardExtremityStmt, params, &rows) if err != nil { return @@ -251,7 +252,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.deleteBackwardExtremitiesForRoomStmt, params, &rows) + s.getPartitionKey(roomID), s.deleteBackwardExtremitiesForRoomStmt, params, &rows) if err != nil { return diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go index 4af177156..611301881 100644 --- a/syncapi/storage/cosmosdb/invites_table.go +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -83,7 +83,8 @@ const selectInviteEventsInRangeSQL = "" + // "SELECT MAX(id) FROM syncapi_invite_events" const selectMaxInviteIDSQL = "" + - "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " + "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " + + "and c._sid = @x2 " type inviteEventsStatements struct { db *SyncServerDatasource @@ -306,7 +307,9 @@ func (s *inviteEventsStatements) SelectMaxInviteID( // err = stmt.QueryRowContext(ctx).Scan(&nullableID) params := map[string]interface{}{ "@x1": s.getCollectionName(), + "@x2": s.db.cosmosConfig.TenantName, } + var rows []inviteEventCosmosMaxNumber err = cosmosdbapi.PerformQueryAllPartitions(ctx, s.db.connection, diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go index e4d0d0f34..988fd9125 100644 --- a/syncapi/storage/cosmosdb/output_room_events_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -115,7 +115,8 @@ const selectEarlyEventsSQL = "" + // "SELECT MAX(id) FROM syncapi_output_room_events" const selectMaxEventIDSQL = "" + - "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " + + "and c._sid = @x2 " // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" const updateEventJSONSQL = "" + @@ -155,6 +156,7 @@ func (s *outputRoomEventsStatements) getCollectionName() string { } func (s *outputRoomEventsStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } @@ -349,6 +351,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( params := map[string]interface{}{ "@x1": s.getCollectionName(), + "@x2": s.db.cosmosConfig.TenantName, } // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) var rows []outputRoomEventCosmosMaxNumber diff --git a/syncapi/storage/cosmosdb/output_room_events_topology_table.go b/syncapi/storage/cosmosdb/output_room_events_topology_table.go index db7d5cfd2..6e9f45d6f 100644 --- a/syncapi/storage/cosmosdb/output_room_events_topology_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_topology_table.go @@ -131,6 +131,7 @@ func (s *outputRoomEventsTopologyStatements) getCollectionName() string { } func (s *outputRoomEventsTopologyStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } diff --git a/syncapi/storage/cosmosdb/receipt_table.go b/syncapi/storage/cosmosdb/receipt_table.go index 73e792c4b..b16970461 100644 --- a/syncapi/storage/cosmosdb/receipt_table.go +++ b/syncapi/storage/cosmosdb/receipt_table.go @@ -76,7 +76,8 @@ const selectRoomReceipts = "" + // "SELECT MAX(id) FROM syncapi_receipts" const selectMaxReceiptIDSQL = "" + - "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " + + "and c._sid = @x2 " type receiptStatements struct { db *SyncServerDatasource @@ -209,6 +210,7 @@ func (s *receiptStatements) SelectMaxReceiptID( params := map[string]interface{}{ "@x1": s.getCollectionName(), + "@x2": s.db.cosmosConfig.TenantName, } var rows []receiptCosmosMaxNumber err = cosmosdbapi.PerformQueryAllPartitions(ctx, diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go index 719d7801a..9dd102531 100644 --- a/syncapi/storage/cosmosdb/send_to_device_table.go +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -83,7 +83,8 @@ const deleteSendToDeviceMessagesSQL = "" + // "SELECT MAX(id) FROM syncapi_send_to_device" const selectMaxSendToDeviceIDSQL = "" + - "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " + + "and c._sid = @x2 " type sendToDeviceStatements struct { db *SyncServerDatasource @@ -273,6 +274,7 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID( params := map[string]interface{}{ "@x1": s.getCollectionName(), + "@x2": s.db.cosmosConfig.TenantName, } var rows []SendToDeviceCosmosMaxNumber err = cosmosdbapi.PerformQueryAllPartitions(ctx, diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index 9b76b1fac..898dbc777 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -62,8 +62,9 @@ func (s *accountDataStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *accountDataStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *accountDataStatements) getPartitionKey(localPart string) string { + uniqueId := localPart + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func (s *accountDataStatements) prepare(db *Database) (err error) { @@ -107,7 +108,7 @@ func (s *accountDataStatements) insertAccountData( docId := id cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getAccountData(s, ctx, s.getPartitionKey(), cosmosDocId) + dbData, _ := getAccountData(s, ctx, s.getPartitionKey(localpart), cosmosDocId) if dbData != nil { // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 dbData.SetUpdateTime() @@ -121,7 +122,7 @@ func (s *accountDataStatements) insertAccountData( } dbData = &accountDataCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(localpart), cosmosDocId), AccountData: result, } } @@ -151,7 +152,7 @@ func (s *accountDataStatements) selectAccountData( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectAccountDataStmt, params, &rows) + s.getPartitionKey(localpart), s.selectAccountDataStmt, params, &rows) if err != nil { return nil, nil, err @@ -193,7 +194,7 @@ func (s *accountDataStatements) selectAccountDataByType( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectAccountDataByTypeStmt, params, &rows) + s.getPartitionKey(localpart), s.selectAccountDataByTypeStmt, params, &rows) if err != nil { return nil, err diff --git a/userapi/storage/accounts/cosmosdb/key_backup_table.go b/userapi/storage/accounts/cosmosdb/key_backup_table.go index bd525e8e7..60f74441f 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_table.go @@ -114,8 +114,9 @@ func (s *keyBackupStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *keyBackupStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *keyBackupStatements) getPartitionKey(userId string) string { + uniqueId := userId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) { @@ -176,7 +177,7 @@ func (s keyBackupStatements) countKeys( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.countKeysStmt, params, &rows) + s.getPartitionKey(userID), s.countKeysStmt, params, &rows) if err != nil { return -1, err @@ -214,7 +215,7 @@ func (s *keyBackupStatements) insertBackupKey( } dbData := &keyBackupCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId), KeyBackup: data, } @@ -242,7 +243,7 @@ func (s *keyBackupStatements) updateBackupKey( docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackup(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getKeyBackup(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return @@ -278,7 +279,7 @@ func (s *keyBackupStatements) selectKeys( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeysStmt, params, &rows) + s.getPartitionKey(userID), s.selectKeysStmt, params, &rows) if err != nil { return nil, err @@ -308,7 +309,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeysByRoomIDStmt, params, &rows) + s.getPartitionKey(userID), s.selectKeysByRoomIDStmt, params, &rows) if err != nil { return nil, err @@ -341,7 +342,7 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - s.getPartitionKey(), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows) + s.getPartitionKey(userID), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows) if err != nil { return nil, err diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go index ecd2c56ef..5079ae41d 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go @@ -95,8 +95,9 @@ func (s *keyBackupVersionStatements) getCollectionName() string { return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func (s *keyBackupVersionStatements) getPartitionKey() string { - return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +func (s *keyBackupVersionStatements) getPartitionKey(userId string) string { + uniqueId := userId + return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId) } func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) { @@ -168,7 +169,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( } dbData := &keyBackupVersionCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId), KeyBackupVersion: data, } @@ -195,7 +196,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return err @@ -225,7 +226,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return err @@ -255,7 +256,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return false, err @@ -324,7 +325,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index cc5daad87..418e48acb 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -169,11 +169,11 @@ func (s *profilesStatements) selectProfileByLocalpart( } if len(rows) == 0 { - return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(rows))) + return nil, errors.New(fmt.Sprintf("Localpart %d not found", len(rows))) } if len(rows) != 1 { - return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(rows))) + return nil, errors.New(fmt.Sprintf("Localpart %d has multiple entries", len(rows))) } var result = mapFromProfile(rows[0].Profile) diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index 3c31d8e50..2d7dab646 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -95,6 +95,7 @@ func (s *devicesStatements) getCollectionName() string { } func (s *devicesStatements) getPartitionKey() string { + //No easy PK, so just use the collection return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } diff --git a/userapi/storage/devices/cosmosdb/storage.go b/userapi/storage/devices/cosmosdb/storage.go index 47c7a6d4e..1dc56d6ec 100644 --- a/userapi/storage/devices/cosmosdb/storage.go +++ b/userapi/storage/devices/cosmosdb/storage.go @@ -122,9 +122,6 @@ func (d *Database) CreateDevice( var err error dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) return dev, err - if returnErr == nil { - return - } } } return