Add UniqueId to PartitionKey for some Dendrite tables (where possible) (#19)

* - Make all PartitionKeys include the tablename
- Update specific PKs to be item specific
- Add validation to the PerformQueryXX methods
- Fix queries that fail validation

* - Revert the PK back to CollectionName as it already includes the TableName

Co-authored-by: alexf@example.com <alexf@example.com>
This commit is contained in:
alexfca 2021-09-23 14:48:32 +10:00 committed by GitHub
parent 927238a686
commit 3088238419
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 184 additions and 116 deletions

View file

@ -92,8 +92,9 @@ func (s *inboundPeeksStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *inboundPeeksStatements) getPartitionKey() string { func (s *inboundPeeksStatements) getPartitionKey(roomId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := roomId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) { func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) {
@ -163,7 +164,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
dbData.InboundPeek.RenewedTimestamp = nowMilli dbData.InboundPeek.RenewedTimestamp = nowMilli
@ -179,7 +180,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
} }
dbData = &inboundPeekCosmosData{ dbData = &inboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
InboundPeek: data, InboundPeek: data,
} }
} }
@ -208,7 +209,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek(
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if err != nil { if err != nil {
return return
@ -233,10 +234,10 @@ func (s *inboundPeeksStatements) SelectInboundPeek(
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
// UNIQUE (room_id, server_name, peek_id) // UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), docId)
// row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) // row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) row, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if row == nil { if row == nil {
return nil, nil return nil, nil
@ -270,7 +271,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInboundPeeksStmt, params, &rows) s.getPartitionKey(roomID), s.selectInboundPeeksStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -307,7 +308,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeek(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -339,7 +340,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeeks(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows)
if err != nil { if err != nil {
return return

View file

@ -89,8 +89,9 @@ func (s *outboundPeeksStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *outboundPeeksStatements) getPartitionKey() string { func (s *outboundPeeksStatements) getPartitionKey(roomId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := roomId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) { func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) {
@ -159,7 +160,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
dbData.OutboundPeek.RenewalInterval = renewalInterval dbData.OutboundPeek.RenewalInterval = renewalInterval
@ -176,7 +177,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
} }
dbData = &outboundPeekCosmosData{ dbData = &outboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
OutboundPeek: data, OutboundPeek: data,
} }
@ -205,7 +206,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek(
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) // _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if err != nil { if err != nil {
return return
@ -233,7 +234,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeek(
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) row, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -273,7 +274,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectOutboundPeeksStmt, params, &rows) s.getPartitionKey(roomID), s.selectOutboundPeeksStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -311,7 +312,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeek(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteOutboundPeekStmt, params, &rows) s.getPartitionKey(roomID), s.deleteOutboundPeekStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -344,7 +345,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeeks(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteOutboundPeeksStmt, params, &rows) s.getPartitionKey(roomID), s.deleteOutboundPeeksStmt, params, &rows)
if err != nil { if err != nil {
return return

View file

@ -2,6 +2,8 @@ package cosmosdbapi
import ( import (
"context" "context"
"errors"
"strings"
"time" "time"
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
@ -55,9 +57,13 @@ func PerformQuery(ctx context.Context,
qryString string, qryString string,
params map[string]interface{}, params map[string]interface{},
response interface{}) error { response interface{}) error {
err := validateQuery(qryString)
if err != nil {
return err
}
optionsQry := GetQueryDocumentsOptions(partitonKey) optionsQry := GetQueryDocumentsOptions(partitonKey)
var query = GetQuery(qryString, params) var query = GetQuery(qryString, params)
_, err := GetClient(conn).QueryDocuments( _, err = GetClient(conn).QueryDocuments(
ctx, ctx,
databaseName, databaseName,
containerName, containerName,
@ -74,9 +80,13 @@ func PerformQueryAllPartitions(ctx context.Context,
qryString string, qryString string,
params map[string]interface{}, params map[string]interface{},
response interface{}) error { response interface{}) error {
err := validateQueryAllPartitions(qryString)
if err != nil {
return err
}
var optionsQry = GetQueryAllPartitionsDocumentsOptions() var optionsQry = GetQueryAllPartitionsDocumentsOptions()
var query = GetQuery(qryString, params) var query = GetQuery(qryString, params)
_, err := GetClient(conn).QueryDocuments( _, err = GetClient(conn).QueryDocuments(
ctx, ctx,
databaseName, databaseName,
containerName, containerName,
@ -130,3 +140,30 @@ func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx cont
return nil return nil
} }
func validateQuery(qryString string) error {
if len(qryString) == 0 {
return errors.New("qryString was nil")
}
if !strings.Contains(qryString, " c._cn = ") {
return errors.New("qryString must contain [ c._cn = ] ")
}
return nil
}
func validateQueryAllPartitions(qryString string) error {
err := validateQuery(qryString)
if err != nil {
return err
}
if strings.Contains(qryString, " top ") {
return errors.New("qryString contains [ top ] ")
}
if strings.Contains(qryString, " order by ") {
return errors.New("qryString contains [ order by ] ")
}
if !strings.Contains(qryString, " c._sid = ") {
return errors.New("qryString must contain [ c._sid = ] ")
}
return nil
}

View file

@ -90,8 +90,9 @@ func (s PartitionOffsetStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName) return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName)
} }
func (s *PartitionOffsetStatements) getPartitionKey() string { func (s *PartitionOffsetStatements) getPartitionKey(topic string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.CosmosConfig.TenantName, s.getCollectionName()) uniqueId := topic
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.CosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) { func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) {
@ -154,7 +155,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
s.db.Connection, s.db.Connection,
s.db.CosmosConfig.DatabaseName, s.db.CosmosConfig.DatabaseName,
s.db.CosmosConfig.ContainerName, s.db.CosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPartitionOffsetsStmt, params, &rows) s.getPartitionKey(topic), s.selectPartitionOffsetsStmt, params, &rows)
// rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) // rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic)
if err != nil { if err != nil {
@ -195,7 +196,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
docId := fmt.Sprintf("%s_%d", topic, partition) docId := fmt.Sprintf("%s_%d", topic, partition)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(topic), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
dbData.PartitionOffset.PartitionOffset = offset dbData.PartitionOffset.PartitionOffset = offset
@ -207,7 +208,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
} }
dbData = &partitionOffsetCosmosData{ dbData = &partitionOffsetCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(topic), cosmosDocId),
PartitionOffset: data, PartitionOffset: data,
} }

View file

@ -116,8 +116,9 @@ func (s *topicsStatements) getCollectionNameMessages() string {
return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages) return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages)
} }
func (s *topicsStatements) getPartitionKeyMessages() string { func (s *topicsStatements) getPartitionKeyMessages(topicNid int64) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages()) uniqueId := fmt.Sprintf("%d", topicNid)
return cosmosdbapi.GetPartitionKeyByUniqueId(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages(), uniqueId)
} }
func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) { func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) {
@ -310,7 +311,7 @@ func (t *topicsStatements) InsertTopics(
} }
dbData := &messageCosmosData{ dbData := &messageCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(topicNID), cosmosDocId),
Message: data, Message: data,
} }
@ -348,7 +349,7 @@ func (t *topicsStatements) SelectMessages(
t.DB.connection, t.DB.connection,
t.DB.cosmosConfig.DatabaseName, t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName, t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyMessages(), t.selectMessagesStmt, params, &rows) t.getPartitionKeyMessages(topicNID), t.selectMessagesStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -387,7 +388,7 @@ func (t *topicsStatements) SelectMaxOffset(
t.DB.connection, t.DB.connection,
t.DB.cosmosConfig.DatabaseName, t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName, t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyMessages(), t.selectMaxOffsetStmt, params, &rows) t.getPartitionKeyMessages(topicNID), t.selectMaxOffsetStmt, params, &rows)
if err != nil { if err != nil {
return 0, err return 0, err

View file

@ -66,8 +66,9 @@ func (s *crossSigningKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *crossSigningKeysStatements) getPartitionKey() string { func (s *crossSigningKeysStatements) getPartitionKey(userId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := userId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) { func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) {
@ -112,7 +113,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningKeysForUserStmt, params, &rows) s.getPartitionKey(userID), s.selectCrossSigningKeysForUserStmt, params, &rows)
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil { if err != nil {
@ -151,7 +152,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
docId := fmt.Sprintf("%s_%s", userID, keyType) docId := fmt.Sprintf("%s_%s", userID, keyType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
dbData.CrossSigningKeys.KeyData = keyData dbData.CrossSigningKeys.KeyData = keyData
@ -163,7 +164,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
} }
dbData = &crossSigningKeysCosmosData{ dbData = &crossSigningKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
CrossSigningKeys: data, CrossSigningKeys: data,
} }
} }

View file

@ -78,8 +78,9 @@ func (s *crossSigningSigsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *crossSigningSigsStatements) getPartitionKey() string { func (s *crossSigningSigsStatements) getPartitionKey(targetUserId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := targetUserId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) { func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) {
@ -145,7 +146,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows)
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
if err != nil { if err != nil {
@ -185,7 +186,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID) docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(targetUserID), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
dbData.CrossSigningSigs.OriginKeyId = string(originKeyID) dbData.CrossSigningSigs.OriginKeyId = string(originKeyID)
@ -200,7 +201,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
} }
dbData = &crossSigningSigsCosmosData{ dbData = &crossSigningSigsCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(targetUserID), cosmosDocId),
CrossSigningSigs: data, CrossSigningSigs: data,
} }
} }
@ -230,7 +231,7 @@ func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows)
// if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { // if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
// return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) // return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)

View file

@ -168,8 +168,9 @@ func (s *deviceKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *deviceKeysStatements) getPartitionKey() string { func (s *deviceKeysStatements) getPartitionKey(userId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := userId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
@ -212,7 +213,7 @@ func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows)
if err != nil { if err != nil {
return err return err
@ -242,7 +243,7 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows)
if err != nil { if err != nil {
return err return err
@ -275,7 +276,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectBatchDeviceKeysStmt, params, &rows) s.getPartitionKey(userID), s.selectBatchDeviceKeysStmt, params, &rows)
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) // rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
@ -327,7 +328,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getDeviceKey(s, ctx, s.getPartitionKey(), cosmosDocId) response, err := getDeviceKey(s, ctx, s.getPartitionKey(key.UserID), cosmosDocId)
if err != nil && err != cosmosdbutil.ErrNoRows { if err != nil && err != cosmosdbutil.ErrNoRows {
return err return err
@ -366,7 +367,7 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectMaxStreamForUserSQL, params, &rows) s.getPartitionKey(userID), selectMaxStreamForUserSQL, params, &rows)
if err != nil { if err != nil {
if err == cosmosdbutil.ErrNoRows { if err == cosmosdbutil.ErrNoRows {
@ -413,7 +414,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), countStreamIDsForUserSQL, params, &rows) s.getPartitionKey(userID), countStreamIDsForUserSQL, params, &rows)
if err != nil { if err != nil {
return 0, err return 0, err
@ -440,7 +441,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData := &deviceKeyCosmosData{ dbData := &deviceKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(key.UserID), cosmosDocId),
DeviceKey: mapFromDeviceKeyMessage(key), DeviceKey: mapFromDeviceKeyMessage(key),
} }

View file

@ -106,8 +106,9 @@ func (s *oneTimeKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *oneTimeKeysStatements) getPartitionKey() string { func (s *oneTimeKeysStatements) getPartitionKey(userId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := userId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) { func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) {
@ -194,7 +195,7 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -239,7 +240,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) s.getPartitionKey(counts.UserID), s.selectKeysCountStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -286,7 +287,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
} }
dbData := &oneTimeKeyCosmosData{ dbData := &oneTimeKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(counts.UserID), cosmosDocId),
OneTimeKey: data, OneTimeKey: data,
} }
@ -309,7 +310,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) s.getPartitionKey(keys.UserID), s.selectKeysCountStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -346,7 +347,7 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows)
if err != nil { if err != nil {
if err == cosmosdbutil.ErrNoRows { if err == cosmosdbutil.ErrNoRows {

View file

@ -99,6 +99,7 @@ func (s *mediaStatements) getCollectionName() string {
} }
func (s *mediaStatements) getPartitionKey() string { func (s *mediaStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }

View file

@ -89,8 +89,9 @@ func (s *thumbnailStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *thumbnailStatements) getPartitionKey() string { func (s *thumbnailStatements) getPartitionKey(mediaId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := mediaId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) { func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) {
@ -163,7 +164,7 @@ func (s *thumbnailStatements) insertThumbnail(
} }
dbData := &thumbnailCosmosData{ dbData := &thumbnailCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(data.MediaID), cosmosDocId),
Thumbnail: data, Thumbnail: data,
} }
@ -209,7 +210,7 @@ func (s *thumbnailStatements) selectThumbnail(
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getThumbnail(s, ctx, s.getPartitionKey(), cosmosDocId) row, err := getThumbnail(s, ctx, s.getPartitionKey(string(mediaID)), cosmosDocId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -250,7 +251,7 @@ func (s *thumbnailStatements) selectThumbnails(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectThumbnailsStmt, params, &rows) s.getPartitionKey(string(mediaID)), s.selectThumbnailsStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -70,6 +70,7 @@ func (s *eventJSONStatements) getCollectionName() string {
} }
func (s *eventJSONStatements) getPartitionKey() string { func (s *eventJSONStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }

View file

@ -178,6 +178,7 @@ func (s *eventStatements) getCollectionName() string {
} }
func (s *eventStatements) getPartitionKey() string { func (s *eventStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }

View file

@ -18,6 +18,7 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/cosmosdbutil"
@ -97,8 +98,9 @@ func (s *inviteStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *inviteStatements) getPartitionKey() string { func (s *inviteStatements) getPartitionKey(roomNId int64) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := fmt.Sprintf("%d", roomNId)
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) { func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) {
@ -169,7 +171,7 @@ func (s *inviteStatements) InsertInviteEvent(
} }
var dbData = inviteCosmosData{ var dbData = inviteCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(int64(roomNID)), cosmosDocId),
Invite: data, Invite: data,
} }
@ -211,7 +213,7 @@ func (s *inviteStatements) UpdateInviteRetired(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInvitesAboutToRetireStmt, params, &rows) s.getPartitionKey(int64(roomNID)), s.selectInvitesAboutToRetireStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -248,7 +250,7 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInviteActiveForUserInRoomStmt, params, &rows) s.getPartitionKey(int64(roomNID)), s.selectInviteActiveForUserInRoomStmt, params, &rows)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View file

@ -89,8 +89,9 @@ func (s *previousEventStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *previousEventStatements) getPartitionKey() string { func (s *previousEventStatements) getPartitionKey(previousEventId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := previousEventId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) { func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) {
@ -141,7 +142,7 @@ func (s *previousEventStatements) InsertPreviousEvent(
// SELECT 1 FROM roomserver_previous_events // SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(previousEventID), cosmosDocId)
if err != nil { if err != nil {
if err != cosmosdbutil.ErrNoRows { if err != cosmosdbutil.ErrNoRows {
@ -159,7 +160,7 @@ func (s *previousEventStatements) InsertPreviousEvent(
} }
dbData = previousEventCosmosData{ dbData = previousEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(previousEventID), cosmosDocId),
PreviousEvent: data, PreviousEvent: data,
} }
} else { } else {
@ -206,7 +207,7 @@ func (s *previousEventStatements) SelectPreviousEventExists(
// SELECT 1 FROM roomserver_previous_events // SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(eventID), cosmosDocId)
if err != nil { if err != nil {
return err return err
} }

View file

@ -180,7 +180,6 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil { if err != nil {
info = nil info = nil
err = err
return return
} }

View file

@ -68,8 +68,9 @@ func (s *transactionStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *transactionStatements) getPartitionKey() string { func (s *transactionStatements) getPartitionKey(transactionID string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := transactionID
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) { func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) {
@ -124,7 +125,7 @@ func (s *transactionStatements) InsertTransaction(
} }
var dbData = transactionCosmosData{ var dbData = transactionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(transactionID), cosmosDocId),
Transaction: data, Transaction: data,
} }
@ -153,7 +154,7 @@ func (s *transactionStatements) SelectTransactionEventID(
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getTransaction(s, ctx, s.getPartitionKey(), cosmosDocId) response, err := getTransaction(s, ctx, s.getPartitionKey(transactionID), cosmosDocId)
if err != nil { if err != nil {
return "", err return "", err

View file

@ -72,7 +72,8 @@ const selectAccountDataInRangeSQL = "" +
// "SELECT MAX(id) FROM syncapi_account_data_type" // "SELECT MAX(id) FROM syncapi_account_data_type"
const selectMaxAccountDataIDSQL = "" + const selectMaxAccountDataIDSQL = "" +
"select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " +
"and c._sid = @x2 "
type accountDataStatements struct { type accountDataStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -88,6 +89,7 @@ func (s *accountDataStatements) getCollectionName() string {
} }
func (s *accountDataStatements) getPartitionKey() string { func (s *accountDataStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }
@ -245,6 +247,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
// err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []AccountDataTypeNumberCosmosData var rows []AccountDataTypeNumberCosmosData

View file

@ -82,8 +82,9 @@ func (s *backwardExtremitiesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *backwardExtremitiesStatements) getPartitionKey() string { func (s *backwardExtremitiesStatements) getPartitionKey(roomId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := roomId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) { func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) {
@ -143,7 +144,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID) docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
if dbData != nil { if dbData != nil {
dbData.SetUpdateTime() dbData.SetUpdateTime()
} else { } else {
@ -154,7 +155,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
} }
dbData = &backwardExtremityCosmosData{ dbData = &backwardExtremityCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
BackwardExtremity: data, BackwardExtremity: data,
} }
} }
@ -185,7 +186,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectBackwardExtremitiesForRoomStmt, params, &rows) s.getPartitionKey(roomID), s.selectBackwardExtremitiesForRoomStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -221,7 +222,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteBackwardExtremityStmt, params, &rows) s.getPartitionKey(roomID), s.deleteBackwardExtremityStmt, params, &rows)
if err != nil { if err != nil {
return return
@ -251,7 +252,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteBackwardExtremitiesForRoomStmt, params, &rows) s.getPartitionKey(roomID), s.deleteBackwardExtremitiesForRoomStmt, params, &rows)
if err != nil { if err != nil {
return return

View file

@ -83,7 +83,8 @@ const selectInviteEventsInRangeSQL = "" +
// "SELECT MAX(id) FROM syncapi_invite_events" // "SELECT MAX(id) FROM syncapi_invite_events"
const selectMaxInviteIDSQL = "" + const selectMaxInviteIDSQL = "" +
"select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " +
"and c._sid = @x2 "
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -306,7 +307,9 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
// err = stmt.QueryRowContext(ctx).Scan(&nullableID) // err = stmt.QueryRowContext(ctx).Scan(&nullableID)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []inviteEventCosmosMaxNumber var rows []inviteEventCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection, s.db.connection,

View file

@ -115,7 +115,8 @@ const selectEarlyEventsSQL = "" +
// "SELECT MAX(id) FROM syncapi_output_room_events" // "SELECT MAX(id) FROM syncapi_output_room_events"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " +
"and c._sid = @x2 "
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const updateEventJSONSQL = "" + const updateEventJSONSQL = "" +
@ -155,6 +156,7 @@ func (s *outputRoomEventsStatements) getCollectionName() string {
} }
func (s *outputRoomEventsStatements) getPartitionKey() string { func (s *outputRoomEventsStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }
@ -349,6 +351,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
// stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
var rows []outputRoomEventCosmosMaxNumber var rows []outputRoomEventCosmosMaxNumber

View file

@ -131,6 +131,7 @@ func (s *outputRoomEventsTopologyStatements) getCollectionName() string {
} }
func (s *outputRoomEventsTopologyStatements) getPartitionKey() string { func (s *outputRoomEventsTopologyStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }

View file

@ -76,7 +76,8 @@ const selectRoomReceipts = "" +
// "SELECT MAX(id) FROM syncapi_receipts" // "SELECT MAX(id) FROM syncapi_receipts"
const selectMaxReceiptIDSQL = "" + const selectMaxReceiptIDSQL = "" +
"select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " +
"and c._sid = @x2 "
type receiptStatements struct { type receiptStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -209,6 +210,7 @@ func (s *receiptStatements) SelectMaxReceiptID(
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []receiptCosmosMaxNumber var rows []receiptCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQueryAllPartitions(ctx,

View file

@ -83,7 +83,8 @@ const deleteSendToDeviceMessagesSQL = "" +
// "SELECT MAX(id) FROM syncapi_send_to_device" // "SELECT MAX(id) FROM syncapi_send_to_device"
const selectMaxSendToDeviceIDSQL = "" + const selectMaxSendToDeviceIDSQL = "" +
"select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " +
"and c._sid = @x2 "
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -273,6 +274,7 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []SendToDeviceCosmosMaxNumber var rows []SendToDeviceCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQueryAllPartitions(ctx,

View file

@ -62,8 +62,9 @@ func (s *accountDataStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *accountDataStatements) getPartitionKey() string { func (s *accountDataStatements) getPartitionKey(localPart string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := localPart
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func (s *accountDataStatements) prepare(db *Database) (err error) { func (s *accountDataStatements) prepare(db *Database) (err error) {
@ -107,7 +108,7 @@ func (s *accountDataStatements) insertAccountData(
docId := id docId := id
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getAccountData(s, ctx, s.getPartitionKey(), cosmosDocId) dbData, _ := getAccountData(s, ctx, s.getPartitionKey(localpart), cosmosDocId)
if dbData != nil { if dbData != nil {
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
dbData.SetUpdateTime() dbData.SetUpdateTime()
@ -121,7 +122,7 @@ func (s *accountDataStatements) insertAccountData(
} }
dbData = &accountDataCosmosData{ dbData = &accountDataCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(localpart), cosmosDocId),
AccountData: result, AccountData: result,
} }
} }
@ -151,7 +152,7 @@ func (s *accountDataStatements) selectAccountData(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountDataStmt, params, &rows) s.getPartitionKey(localpart), s.selectAccountDataStmt, params, &rows)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -193,7 +194,7 @@ func (s *accountDataStatements) selectAccountDataByType(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountDataByTypeStmt, params, &rows) s.getPartitionKey(localpart), s.selectAccountDataByTypeStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -114,8 +114,9 @@ func (s *keyBackupStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *keyBackupStatements) getPartitionKey() string { func (s *keyBackupStatements) getPartitionKey(userId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := userId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) { func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) {
@ -176,7 +177,7 @@ func (s keyBackupStatements) countKeys(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.countKeysStmt, params, &rows) s.getPartitionKey(userID), s.countKeysStmt, params, &rows)
if err != nil { if err != nil {
return -1, err return -1, err
@ -214,7 +215,7 @@ func (s *keyBackupStatements) insertBackupKey(
} }
dbData := &keyBackupCosmosData{ dbData := &keyBackupCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
KeyBackup: data, KeyBackup: data,
} }
@ -242,7 +243,7 @@ func (s *keyBackupStatements) updateBackupKey(
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackup(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getKeyBackup(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if err != nil { if err != nil {
return return
@ -278,7 +279,7 @@ func (s *keyBackupStatements) selectKeys(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysStmt, params, &rows) s.getPartitionKey(userID), s.selectKeysStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -308,7 +309,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysByRoomIDStmt, params, &rows) s.getPartitionKey(userID), s.selectKeysByRoomIDStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err
@ -341,7 +342,7 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows) s.getPartitionKey(userID), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -95,8 +95,9 @@ func (s *keyBackupVersionStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
} }
func (s *keyBackupVersionStatements) getPartitionKey() string { func (s *keyBackupVersionStatements) getPartitionKey(userId string) string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) uniqueId := userId
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
} }
func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) { func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) {
@ -168,7 +169,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
} }
dbData := &keyBackupVersionCosmosData{ dbData := &keyBackupVersionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
KeyBackupVersion: data, KeyBackupVersion: data,
} }
@ -195,7 +196,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
docId := fmt.Sprintf("%s_%d", userID, versionInt) docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if err != nil { if err != nil {
return err return err
@ -225,7 +226,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
docId := fmt.Sprintf("%s_%d", userID, versionInt) docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if err != nil { if err != nil {
return err return err
@ -255,7 +256,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
docId := fmt.Sprintf("%s_%d", userID, versionInt) docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if err != nil { if err != nil {
return false, err return false, err
@ -324,7 +325,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
docId := fmt.Sprintf("%s_%d", userID, versionInt) docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
if err != nil { if err != nil {
return return

View file

@ -169,11 +169,11 @@ func (s *profilesStatements) selectProfileByLocalpart(
} }
if len(rows) == 0 { if len(rows) == 0 {
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(rows))) return nil, errors.New(fmt.Sprintf("Localpart %d not found", len(rows)))
} }
if len(rows) != 1 { if len(rows) != 1 {
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(rows))) return nil, errors.New(fmt.Sprintf("Localpart %d has multiple entries", len(rows)))
} }
var result = mapFromProfile(rows[0].Profile) var result = mapFromProfile(rows[0].Profile)

View file

@ -95,6 +95,7 @@ func (s *devicesStatements) getCollectionName() string {
} }
func (s *devicesStatements) getPartitionKey() string { func (s *devicesStatements) getPartitionKey() string {
//No easy PK, so just use the collection
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
} }

View file

@ -122,9 +122,6 @@ func (d *Database) CreateDevice(
var err error var err error
dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return dev, err return dev, err
if returnErr == nil {
return
}
} }
} }
return return