mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Add UniqueId to PartitionKey for some Dendrite tables (where possible) (#19)
* - Make all PartitionKeys include the tablename - Update specific PKs to be item specific - Add validation to the PerformQueryXX methods - Fix queries that fail validation * - Revert the PK back to CollectionName as it already includes the TableName Co-authored-by: alexf@example.com <alexf@example.com>
This commit is contained in:
parent
927238a686
commit
3088238419
|
|
@ -92,8 +92,9 @@ func (s *inboundPeeksStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *inboundPeeksStatements) getPartitionKey(roomId string) string {
|
||||
uniqueId := roomId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) {
|
||||
|
|
@ -163,7 +164,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
|
|||
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
dbData.InboundPeek.RenewedTimestamp = nowMilli
|
||||
|
|
@ -179,7 +180,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
|
|||
}
|
||||
|
||||
dbData = &inboundPeekCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
|
||||
InboundPeek: data,
|
||||
}
|
||||
}
|
||||
|
|
@ -208,7 +209,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek(
|
|||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
res, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -233,10 +234,10 @@ func (s *inboundPeeksStatements) SelectInboundPeek(
|
|||
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
// UNIQUE (room_id, server_name, peek_id)
|
||||
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(), docId)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), docId)
|
||||
|
||||
// row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
row, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
row, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
|
||||
if row == nil {
|
||||
return nil, nil
|
||||
|
|
@ -270,7 +271,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectInboundPeeksStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.selectInboundPeeksStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -307,7 +308,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeek(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -339,7 +340,7 @@ func (s *inboundPeeksStatements) DeleteInboundPeeks(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteInboundPeekStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -89,8 +89,9 @@ func (s *outboundPeeksStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *outboundPeeksStatements) getPartitionKey(roomId string) string {
|
||||
uniqueId := roomId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) {
|
||||
|
|
@ -159,7 +160,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
|
|||
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
dbData.OutboundPeek.RenewalInterval = renewalInterval
|
||||
|
|
@ -176,7 +177,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
|
|||
}
|
||||
|
||||
dbData = &outboundPeekCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
|
||||
OutboundPeek: data,
|
||||
}
|
||||
|
||||
|
|
@ -205,7 +206,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek(
|
|||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
// _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
res, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -233,7 +234,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeek(
|
|||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
row, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
row, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -273,7 +274,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectOutboundPeeksStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.selectOutboundPeeksStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -311,7 +312,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeek(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteOutboundPeekStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteOutboundPeekStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -344,7 +345,7 @@ func (s *outboundPeeksStatements) DeleteOutboundPeeks(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteOutboundPeeksStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteOutboundPeeksStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ package cosmosdbapi
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi"
|
||||
|
|
@ -55,9 +57,13 @@ func PerformQuery(ctx context.Context,
|
|||
qryString string,
|
||||
params map[string]interface{},
|
||||
response interface{}) error {
|
||||
err := validateQuery(qryString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
optionsQry := GetQueryDocumentsOptions(partitonKey)
|
||||
var query = GetQuery(qryString, params)
|
||||
_, err := GetClient(conn).QueryDocuments(
|
||||
_, err = GetClient(conn).QueryDocuments(
|
||||
ctx,
|
||||
databaseName,
|
||||
containerName,
|
||||
|
|
@ -74,9 +80,13 @@ func PerformQueryAllPartitions(ctx context.Context,
|
|||
qryString string,
|
||||
params map[string]interface{},
|
||||
response interface{}) error {
|
||||
err := validateQueryAllPartitions(qryString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var optionsQry = GetQueryAllPartitionsDocumentsOptions()
|
||||
var query = GetQuery(qryString, params)
|
||||
_, err := GetClient(conn).QueryDocuments(
|
||||
_, err = GetClient(conn).QueryDocuments(
|
||||
ctx,
|
||||
databaseName,
|
||||
containerName,
|
||||
|
|
@ -130,3 +140,30 @@ func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx cont
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateQuery(qryString string) error {
|
||||
if len(qryString) == 0 {
|
||||
return errors.New("qryString was nil")
|
||||
}
|
||||
if !strings.Contains(qryString, " c._cn = ") {
|
||||
return errors.New("qryString must contain [ c._cn = ] ")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateQueryAllPartitions(qryString string) error {
|
||||
err := validateQuery(qryString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(qryString, " top ") {
|
||||
return errors.New("qryString contains [ top ] ")
|
||||
}
|
||||
if strings.Contains(qryString, " order by ") {
|
||||
return errors.New("qryString contains [ order by ] ")
|
||||
}
|
||||
if !strings.Contains(qryString, " c._sid = ") {
|
||||
return errors.New("qryString must contain [ c._sid = ] ")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,8 +90,9 @@ func (s PartitionOffsetStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName)
|
||||
}
|
||||
|
||||
func (s *PartitionOffsetStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.CosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *PartitionOffsetStatements) getPartitionKey(topic string) string {
|
||||
uniqueId := topic
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.CosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) {
|
||||
|
|
@ -154,7 +155,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
|
|||
s.db.Connection,
|
||||
s.db.CosmosConfig.DatabaseName,
|
||||
s.db.CosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectPartitionOffsetsStmt, params, &rows)
|
||||
s.getPartitionKey(topic), s.selectPartitionOffsetsStmt, params, &rows)
|
||||
|
||||
// rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic)
|
||||
if err != nil {
|
||||
|
|
@ -195,7 +196,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
|
|||
docId := fmt.Sprintf("%s_%d", topic, partition)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(topic), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
dbData.PartitionOffset.PartitionOffset = offset
|
||||
|
|
@ -207,7 +208,7 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
|
|||
}
|
||||
|
||||
dbData = &partitionOffsetCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(topic), cosmosDocId),
|
||||
PartitionOffset: data,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -116,8 +116,9 @@ func (s *topicsStatements) getCollectionNameMessages() string {
|
|||
return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages)
|
||||
}
|
||||
|
||||
func (s *topicsStatements) getPartitionKeyMessages() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages())
|
||||
func (s *topicsStatements) getPartitionKeyMessages(topicNid int64) string {
|
||||
uniqueId := fmt.Sprintf("%d", topicNid)
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages(), uniqueId)
|
||||
}
|
||||
|
||||
func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) {
|
||||
|
|
@ -310,7 +311,7 @@ func (t *topicsStatements) InsertTopics(
|
|||
}
|
||||
|
||||
dbData := &messageCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(topicNID), cosmosDocId),
|
||||
Message: data,
|
||||
}
|
||||
|
||||
|
|
@ -348,7 +349,7 @@ func (t *topicsStatements) SelectMessages(
|
|||
t.DB.connection,
|
||||
t.DB.cosmosConfig.DatabaseName,
|
||||
t.DB.cosmosConfig.ContainerName,
|
||||
t.getPartitionKeyMessages(), t.selectMessagesStmt, params, &rows)
|
||||
t.getPartitionKeyMessages(topicNID), t.selectMessagesStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -387,7 +388,7 @@ func (t *topicsStatements) SelectMaxOffset(
|
|||
t.DB.connection,
|
||||
t.DB.cosmosConfig.DatabaseName,
|
||||
t.DB.cosmosConfig.ContainerName,
|
||||
t.getPartitionKeyMessages(), t.selectMaxOffsetStmt, params, &rows)
|
||||
t.getPartitionKeyMessages(topicNID), t.selectMaxOffsetStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
|
|||
|
|
@ -66,8 +66,9 @@ func (s *crossSigningKeysStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *crossSigningKeysStatements) getPartitionKey(userId string) string {
|
||||
uniqueId := userId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) {
|
||||
|
|
@ -112,7 +113,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectCrossSigningKeysForUserStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectCrossSigningKeysForUserStmt, params, &rows)
|
||||
|
||||
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
|
|
@ -151,7 +152,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
|||
docId := fmt.Sprintf("%s_%s", userID, keyType)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
dbData.CrossSigningKeys.KeyData = keyData
|
||||
|
|
@ -163,7 +164,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
|||
}
|
||||
|
||||
dbData = &crossSigningKeysCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
|
||||
CrossSigningKeys: data,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,8 +78,9 @@ func (s *crossSigningSigsStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *crossSigningSigsStatements) getPartitionKey(targetUserId string) string {
|
||||
uniqueId := targetUserId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) {
|
||||
|
|
@ -145,7 +146,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows)
|
||||
s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows)
|
||||
|
||||
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
|
|
@ -185,7 +186,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
|||
docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(targetUserID), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
dbData.CrossSigningSigs.OriginKeyId = string(originKeyID)
|
||||
|
|
@ -200,7 +201,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
|||
}
|
||||
|
||||
dbData = &crossSigningSigsCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(targetUserID), cosmosDocId),
|
||||
CrossSigningSigs: data,
|
||||
}
|
||||
}
|
||||
|
|
@ -230,7 +231,7 @@ func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows)
|
||||
s.getPartitionKey(targetUserID), s.selectCrossSigningSigsForTargetStmt, params, &rows)
|
||||
|
||||
// if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||
// return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||
|
|
|
|||
|
|
@ -168,8 +168,9 @@ func (s *deviceKeysStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *deviceKeysStatements) getPartitionKey(userId string) string {
|
||||
uniqueId := userId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
|
||||
|
|
@ -212,7 +213,7 @@ func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows)
|
||||
s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -242,7 +243,7 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows)
|
||||
s.getPartitionKey(userID), selectAllDeviceKeysSQL, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -275,7 +276,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectBatchDeviceKeysStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectBatchDeviceKeysStmt, params, &rows)
|
||||
|
||||
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
|
|
@ -327,7 +328,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
response, err := getDeviceKey(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
response, err := getDeviceKey(s, ctx, s.getPartitionKey(key.UserID), cosmosDocId)
|
||||
|
||||
if err != nil && err != cosmosdbutil.ErrNoRows {
|
||||
return err
|
||||
|
|
@ -366,7 +367,7 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), selectMaxStreamForUserSQL, params, &rows)
|
||||
s.getPartitionKey(userID), selectMaxStreamForUserSQL, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
if err == cosmosdbutil.ErrNoRows {
|
||||
|
|
@ -413,7 +414,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), countStreamIDsForUserSQL, params, &rows)
|
||||
s.getPartitionKey(userID), countStreamIDsForUserSQL, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
@ -440,7 +441,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData := &deviceKeyCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(key.UserID), cosmosDocId),
|
||||
DeviceKey: mapFromDeviceKeyMessage(key),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -106,8 +106,9 @@ func (s *oneTimeKeysStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *oneTimeKeysStatements) getPartitionKey(userId string) string {
|
||||
uniqueId := userId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) {
|
||||
|
|
@ -194,7 +195,7 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -239,7 +240,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows)
|
||||
s.getPartitionKey(counts.UserID), s.selectKeysCountStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -286,7 +287,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
|||
}
|
||||
|
||||
dbData := &oneTimeKeyCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(counts.UserID), cosmosDocId),
|
||||
OneTimeKey: data,
|
||||
}
|
||||
|
||||
|
|
@ -309,7 +310,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows)
|
||||
s.getPartitionKey(keys.UserID), s.selectKeysCountStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -346,7 +347,7 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectKeyByAlgorithmStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
if err == cosmosdbutil.ErrNoRows {
|
||||
|
|
|
|||
|
|
@ -99,6 +99,7 @@ func (s *mediaStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *mediaStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -89,8 +89,9 @@ func (s *thumbnailStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *thumbnailStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *thumbnailStatements) getPartitionKey(mediaId string) string {
|
||||
uniqueId := mediaId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) {
|
||||
|
|
@ -163,7 +164,7 @@ func (s *thumbnailStatements) insertThumbnail(
|
|||
}
|
||||
|
||||
dbData := &thumbnailCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(data.MediaID), cosmosDocId),
|
||||
Thumbnail: data,
|
||||
}
|
||||
|
||||
|
|
@ -209,7 +210,7 @@ func (s *thumbnailStatements) selectThumbnail(
|
|||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
row, err := getThumbnail(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
row, err := getThumbnail(s, ctx, s.getPartitionKey(string(mediaID)), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -250,7 +251,7 @@ func (s *thumbnailStatements) selectThumbnails(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectThumbnailsStmt, params, &rows)
|
||||
s.getPartitionKey(string(mediaID)), s.selectThumbnailsStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ func (s *eventJSONStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *eventJSONStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ func (s *eventStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *eventStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ package cosmosdb
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||
|
|
@ -97,8 +98,9 @@ func (s *inviteStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *inviteStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *inviteStatements) getPartitionKey(roomNId int64) string {
|
||||
uniqueId := fmt.Sprintf("%d", roomNId)
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) {
|
||||
|
|
@ -169,7 +171,7 @@ func (s *inviteStatements) InsertInviteEvent(
|
|||
}
|
||||
|
||||
var dbData = inviteCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(int64(roomNID)), cosmosDocId),
|
||||
Invite: data,
|
||||
}
|
||||
|
||||
|
|
@ -211,7 +213,7 @@ func (s *inviteStatements) UpdateInviteRetired(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectInvitesAboutToRetireStmt, params, &rows)
|
||||
s.getPartitionKey(int64(roomNID)), s.selectInvitesAboutToRetireStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -248,7 +250,7 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectInviteActiveForUserInRoomStmt, params, &rows)
|
||||
s.getPartitionKey(int64(roomNID)), s.selectInviteActiveForUserInRoomStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
|||
|
|
@ -89,8 +89,9 @@ func (s *previousEventStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *previousEventStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *previousEventStatements) getPartitionKey(previousEventId string) string {
|
||||
uniqueId := previousEventId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) {
|
||||
|
|
@ -141,7 +142,7 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
|||
|
||||
// SELECT 1 FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(previousEventID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
if err != cosmosdbutil.ErrNoRows {
|
||||
|
|
@ -159,7 +160,7 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
|||
}
|
||||
|
||||
dbData = previousEventCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(previousEventID), cosmosDocId),
|
||||
PreviousEvent: data,
|
||||
}
|
||||
} else {
|
||||
|
|
@ -206,7 +207,7 @@ func (s *previousEventStatements) SelectPreviousEventExists(
|
|||
|
||||
// SELECT 1 FROM roomserver_previous_events
|
||||
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
|
||||
dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(eventID), cosmosDocId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,7 +180,6 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
|||
response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
if err != nil {
|
||||
info = nil
|
||||
err = err
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,8 +68,9 @@ func (s *transactionStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *transactionStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *transactionStatements) getPartitionKey(transactionID string) string {
|
||||
uniqueId := transactionID
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) {
|
||||
|
|
@ -124,7 +125,7 @@ func (s *transactionStatements) InsertTransaction(
|
|||
}
|
||||
|
||||
var dbData = transactionCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(transactionID), cosmosDocId),
|
||||
Transaction: data,
|
||||
}
|
||||
|
||||
|
|
@ -153,7 +154,7 @@ func (s *transactionStatements) SelectTransactionEventID(
|
|||
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
response, err := getTransaction(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
response, err := getTransaction(s, ctx, s.getPartitionKey(transactionID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
|||
|
|
@ -72,7 +72,8 @@ const selectAccountDataInRangeSQL = "" +
|
|||
|
||||
// "SELECT MAX(id) FROM syncapi_account_data_type"
|
||||
const selectMaxAccountDataIDSQL = "" +
|
||||
"select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 "
|
||||
"select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " +
|
||||
"and c._sid = @x2 "
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *SyncServerDatasource
|
||||
|
|
@ -88,6 +89,7 @@ func (s *accountDataStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
@ -245,6 +247,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
|
|||
// err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
|
||||
params := map[string]interface{}{
|
||||
"@x1": s.getCollectionName(),
|
||||
"@x2": s.db.cosmosConfig.TenantName,
|
||||
}
|
||||
|
||||
var rows []AccountDataTypeNumberCosmosData
|
||||
|
|
|
|||
|
|
@ -82,8 +82,9 @@ func (s *backwardExtremitiesStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *backwardExtremitiesStatements) getPartitionKey(roomId string) string {
|
||||
uniqueId := roomId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) {
|
||||
|
|
@ -143,7 +144,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
|||
docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(roomID), cosmosDocId)
|
||||
if dbData != nil {
|
||||
dbData.SetUpdateTime()
|
||||
} else {
|
||||
|
|
@ -154,7 +155,7 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
|||
}
|
||||
|
||||
dbData = &backwardExtremityCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(roomID), cosmosDocId),
|
||||
BackwardExtremity: data,
|
||||
}
|
||||
}
|
||||
|
|
@ -185,7 +186,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectBackwardExtremitiesForRoomStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.selectBackwardExtremitiesForRoomStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -221,7 +222,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteBackwardExtremityStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteBackwardExtremityStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -251,7 +252,7 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.deleteBackwardExtremitiesForRoomStmt, params, &rows)
|
||||
s.getPartitionKey(roomID), s.deleteBackwardExtremitiesForRoomStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -83,7 +83,8 @@ const selectInviteEventsInRangeSQL = "" +
|
|||
|
||||
// "SELECT MAX(id) FROM syncapi_invite_events"
|
||||
const selectMaxInviteIDSQL = "" +
|
||||
"select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 "
|
||||
"select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " +
|
||||
"and c._sid = @x2 "
|
||||
|
||||
type inviteEventsStatements struct {
|
||||
db *SyncServerDatasource
|
||||
|
|
@ -306,7 +307,9 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
|
|||
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||
params := map[string]interface{}{
|
||||
"@x1": s.getCollectionName(),
|
||||
"@x2": s.db.cosmosConfig.TenantName,
|
||||
}
|
||||
|
||||
var rows []inviteEventCosmosMaxNumber
|
||||
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
|
||||
s.db.connection,
|
||||
|
|
|
|||
|
|
@ -115,7 +115,8 @@ const selectEarlyEventsSQL = "" +
|
|||
|
||||
// "SELECT MAX(id) FROM syncapi_output_room_events"
|
||||
const selectMaxEventIDSQL = "" +
|
||||
"select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 "
|
||||
"select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " +
|
||||
"and c._sid = @x2 "
|
||||
|
||||
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
|
||||
const updateEventJSONSQL = "" +
|
||||
|
|
@ -155,6 +156,7 @@ func (s *outputRoomEventsStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
@ -349,6 +351,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
|
|||
|
||||
params := map[string]interface{}{
|
||||
"@x1": s.getCollectionName(),
|
||||
"@x2": s.db.cosmosConfig.TenantName,
|
||||
}
|
||||
// stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
|
||||
var rows []outputRoomEventCosmosMaxNumber
|
||||
|
|
|
|||
|
|
@ -131,6 +131,7 @@ func (s *outputRoomEventsTopologyStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,8 @@ const selectRoomReceipts = "" +
|
|||
|
||||
// "SELECT MAX(id) FROM syncapi_receipts"
|
||||
const selectMaxReceiptIDSQL = "" +
|
||||
"select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 "
|
||||
"select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " +
|
||||
"and c._sid = @x2 "
|
||||
|
||||
type receiptStatements struct {
|
||||
db *SyncServerDatasource
|
||||
|
|
@ -209,6 +210,7 @@ func (s *receiptStatements) SelectMaxReceiptID(
|
|||
|
||||
params := map[string]interface{}{
|
||||
"@x1": s.getCollectionName(),
|
||||
"@x2": s.db.cosmosConfig.TenantName,
|
||||
}
|
||||
var rows []receiptCosmosMaxNumber
|
||||
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
|
||||
|
|
|
|||
|
|
@ -83,7 +83,8 @@ const deleteSendToDeviceMessagesSQL = "" +
|
|||
|
||||
// "SELECT MAX(id) FROM syncapi_send_to_device"
|
||||
const selectMaxSendToDeviceIDSQL = "" +
|
||||
"select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 "
|
||||
"select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " +
|
||||
"and c._sid = @x2 "
|
||||
|
||||
type sendToDeviceStatements struct {
|
||||
db *SyncServerDatasource
|
||||
|
|
@ -273,6 +274,7 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
|
|||
|
||||
params := map[string]interface{}{
|
||||
"@x1": s.getCollectionName(),
|
||||
"@x2": s.db.cosmosConfig.TenantName,
|
||||
}
|
||||
var rows []SendToDeviceCosmosMaxNumber
|
||||
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
|
||||
|
|
|
|||
|
|
@ -62,8 +62,9 @@ func (s *accountDataStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *accountDataStatements) getPartitionKey(localPart string) string {
|
||||
uniqueId := localPart
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) prepare(db *Database) (err error) {
|
||||
|
|
@ -107,7 +108,7 @@ func (s *accountDataStatements) insertAccountData(
|
|||
docId := id
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
dbData, _ := getAccountData(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
dbData, _ := getAccountData(s, ctx, s.getPartitionKey(localpart), cosmosDocId)
|
||||
if dbData != nil {
|
||||
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
dbData.SetUpdateTime()
|
||||
|
|
@ -121,7 +122,7 @@ func (s *accountDataStatements) insertAccountData(
|
|||
}
|
||||
|
||||
dbData = &accountDataCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(localpart), cosmosDocId),
|
||||
AccountData: result,
|
||||
}
|
||||
}
|
||||
|
|
@ -151,7 +152,7 @@ func (s *accountDataStatements) selectAccountData(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectAccountDataStmt, params, &rows)
|
||||
s.getPartitionKey(localpart), s.selectAccountDataStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
@ -193,7 +194,7 @@ func (s *accountDataStatements) selectAccountDataByType(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectAccountDataByTypeStmt, params, &rows)
|
||||
s.getPartitionKey(localpart), s.selectAccountDataByTypeStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -114,8 +114,9 @@ func (s *keyBackupStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *keyBackupStatements) getPartitionKey(userId string) string {
|
||||
uniqueId := userId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) {
|
||||
|
|
@ -176,7 +177,7 @@ func (s keyBackupStatements) countKeys(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.countKeysStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.countKeysStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return -1, err
|
||||
|
|
@ -214,7 +215,7 @@ func (s *keyBackupStatements) insertBackupKey(
|
|||
}
|
||||
|
||||
dbData := &keyBackupCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
|
||||
KeyBackup: data,
|
||||
}
|
||||
|
||||
|
|
@ -242,7 +243,7 @@ func (s *keyBackupStatements) updateBackupKey(
|
|||
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
res, err := getKeyBackup(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getKeyBackup(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -278,7 +279,7 @@ func (s *keyBackupStatements) selectKeys(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeysStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectKeysStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -308,7 +309,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeysByRoomIDStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectKeysByRoomIDStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -341,7 +342,7 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
|||
s.db.connection,
|
||||
s.db.cosmosConfig.DatabaseName,
|
||||
s.db.cosmosConfig.ContainerName,
|
||||
s.getPartitionKey(), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows)
|
||||
s.getPartitionKey(userID), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -95,8 +95,9 @@ func (s *keyBackupVersionStatements) getCollectionName() string {
|
|||
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||
}
|
||||
|
||||
func (s *keyBackupVersionStatements) getPartitionKey() string {
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
func (s *keyBackupVersionStatements) getPartitionKey(userId string) string {
|
||||
uniqueId := userId
|
||||
return cosmosdbapi.GetPartitionKeyByUniqueId(s.db.cosmosConfig.TenantName, s.getCollectionName(), uniqueId)
|
||||
}
|
||||
|
||||
func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) {
|
||||
|
|
@ -168,7 +169,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
|
|||
}
|
||||
|
||||
dbData := &keyBackupVersionCosmosData{
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
|
||||
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(userID), cosmosDocId),
|
||||
KeyBackupVersion: data,
|
||||
}
|
||||
|
||||
|
|
@ -195,7 +196,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
|||
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -225,7 +226,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
|||
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -255,7 +256,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
|
|||
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
@ -324,7 +325,7 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
|
|||
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
|
||||
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
|
||||
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
|||
|
|
@ -169,11 +169,11 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(rows)))
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %d not found", len(rows)))
|
||||
}
|
||||
|
||||
if len(rows) != 1 {
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(rows)))
|
||||
return nil, errors.New(fmt.Sprintf("Localpart %d has multiple entries", len(rows)))
|
||||
}
|
||||
|
||||
var result = mapFromProfile(rows[0].Profile)
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ func (s *devicesStatements) getCollectionName() string {
|
|||
}
|
||||
|
||||
func (s *devicesStatements) getPartitionKey() string {
|
||||
//No easy PK, so just use the collection
|
||||
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -122,9 +122,6 @@ func (d *Database) CreateDevice(
|
|||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
return dev, err
|
||||
if returnErr == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in a new issue