- Remove PerformQueryAllPartitions as it does not support aggreates (#24)

- Update queries to all use PartitionKeys
- Remove the _sid from queries as the PK contains the Tenant
- Fix some bugs around empty values and ordering

Co-authored-by: alexf@example.com <alexf@example.com>
This commit is contained in:
alexfca 2021-10-08 11:17:22 +11:00 committed by GitHub
parent 49f8c7fe38
commit 60e11f88b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 71 additions and 125 deletions

View file

@ -84,35 +84,6 @@ func PerformQuery(ctx context.Context,
return err return err
} }
func PerformQueryAllPartitions(ctx context.Context,
conn CosmosConnection,
databaseName string,
containerName string,
qryString string,
params map[string]interface{},
response interface{}) error {
err := validateQueryAllPartitions(qryString)
if err != nil {
return err
}
var optionsQry = getQueryAllPartitionsDocumentsOptions()
var query = getQuery(qryString, params)
_, err = GetClient(conn).QueryDocuments(
ctx,
databaseName,
containerName,
query,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
}
return nil
}
func GenerateDocument( func GenerateDocument(
collection string, collection string,
tenantName string, tenantName string,
@ -181,20 +152,3 @@ func validateQuery(qryString string) error {
} }
return nil return nil
} }
func validateQueryAllPartitions(qryString string) error {
err := validateQuery(qryString)
if err != nil {
return err
}
if strings.Contains(qryString, " top ") {
return errors.New("qryString contains [ top ] ")
}
if strings.Contains(qryString, " order by ") {
return errors.New("qryString contains [ order by ] ")
}
if !strings.Contains(qryString, " c._sid = ") {
return errors.New("qryString must contain [ c._sid = ] ")
}
return nil
}

View file

@ -26,14 +26,6 @@ func getQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions {
} }
} }
func getQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions {
return cosmosapi.QueryDocumentsOptions{
IsQuery: true,
EnableCrossPartition: true,
ContentType: cosmosapi.QUERY_CONTENT_TYPE,
}
}
func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions { func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions {
return cosmosapi.GetDocumentOptions{ return cosmosapi.GetDocumentOptions{
PartitionKeyValue: pk, PartitionKeyValue: pk,

View file

@ -75,14 +75,14 @@ const selectBatchDeviceKeysSQL = "" +
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"select max(c.mx_keyserver_device_key.stream_id) as number from c where c._sid = @x1 and c._cn = @x2 " + "select max(c.mx_keyserver_device_key.stream_id) as number from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x3 " "and c.mx_keyserver_device_key.user_id = @x2 "
// "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" // "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const countStreamIDsForUserSQL = "" + const countStreamIDsForUserSQL = "" +
"select count(c._ts) as number from c where c._sid = @x1 and c._cn = @x2 " + "select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_keyserver_device_key.user_id = @x3 " + "and c.mx_keyserver_device_key.user_id = @x2 " +
"and ARRAY_CONTAINS(@x4, c.mx_keyserver_device_key.stream_id) " "and ARRAY_CONTAINS(@x3, c.mx_keyserver_device_key.stream_id) "
const selectAllDeviceKeysSQL = "" + const selectAllDeviceKeysSQL = "" +
"select * from c where c._cn = @x1 " + "select * from c where c._cn = @x1 " +
@ -356,9 +356,8 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": userID,
"@x3": userID,
} }
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
@ -398,10 +397,9 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
} }
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": userID,
"@x3": userID, "@x3": iStreamIDs,
"@x4": iStreamIDs,
} }
// query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) // query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)

View file

@ -65,10 +65,10 @@ type keyChangeCosmosData struct {
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" // "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
const selectKeyChangesSQL = "" + const selectKeyChangesSQL = "" +
"select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " + "select c.mx_keyserver_key_change.user_id as user_id, max(c.mx_keyserver_key_change._offset) as max_offset " +
"from c where c._sid = @x1 and c._cn = @x2 " + "from c where c._cn = @x1 " +
"and c.mx_keyserver_key_change.partition = @x3 " + "and c.mx_keyserver_key_change.partition = @x2 " +
"and c.mx_keyserver_key_change._offset > @x4 " + "and c.mx_keyserver_key_change._offset > @x3 " +
"and c.mx_keyserver_key_change._offset < @x5 " + "and c.mx_keyserver_key_change._offset < @x4 " +
"group by c.mx_keyserver_key_change.user_id " "group by c.mx_keyserver_key_change.user_id "
type keyChangesStatements struct { type keyChangesStatements struct {
@ -161,18 +161,18 @@ func (s *keyChangesStatements) SelectKeyChanges(
// rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) // rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": partition,
"@x3": partition, "@x3": fromOffset,
"@x4": fromOffset, "@x4": toOffset,
"@x5": toOffset,
} }
var rows []keyChangeUserMaxCosmosData var rows []keyChangeUserMaxCosmosData
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectKeyChangesStmt, params, &rows) s.selectKeyChangesStmt, params, &rows)
if err != nil { if err != nil {

View file

@ -55,9 +55,9 @@ type staleDeviceListCosmosData struct {
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsWithDomainsSQL = "" +
"select * from c where c._sid = @x1 and c._cn = @x2 " + "select * from c where c._cn = @x1 " +
"and c.mx_keyserver_stale_device_list.is_stale = @x3 " + "and c.mx_keyserver_stale_device_list.is_stale = @x2 " +
"and c.mx_keyserver_stale_device_list.domain = @x4 " "and c.mx_keyserver_stale_device_list.domain = @x3 "
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
const selectStaleDeviceListsSQL = "" + const selectStaleDeviceListsSQL = "" +
@ -156,9 +156,8 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
// rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) // rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": true,
"@x3": true,
} }
var rows []staleDeviceListCosmosData var rows []staleDeviceListCosmosData

View file

@ -146,9 +146,8 @@ const bulkSelectEventNIDSQL = "" +
// "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectMaxEventDepthSQL = "" + const selectMaxEventDepthSQL = "" +
"select sub.maxinner != null ? sub.maxinner + 1 : 0 as maxdepth from " + "select MAX(c.mx_roomserver_event.depth) maxdepth from c where c._cn = @x1 " +
"(select MAX(c.mx_roomserver_event.depth) maxinner from c where c._sid = @x1 and c._cn = @x2 " + " and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid)"
" and ARRAY_CONTAINS(@x3, c.mx_roomserver_event.event_nid)) sub"
// "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" // "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDsForEventNIDsSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" +
@ -877,22 +876,29 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
// "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": eventNIDs,
"@x3": eventNIDs,
} }
var rows []eventCosmosMaxDepth var rows []eventCosmosMaxDepth
err := cosmosdbapi.PerformQueryAllPartitions(ctx, err := cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
selectMaxEventDepthSQL, params, &rows) selectMaxEventDepthSQL, params, &rows)
if err != nil { if err != nil {
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
} }
return rows[0].Max, nil if len(rows) == 0 {
return 0, nil
}
result := rows[0].Max
if result == 0 {
return 0, nil
}
return result + 1, nil
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(

View file

@ -72,8 +72,7 @@ const selectAccountDataInRangeSQL = "" +
// "SELECT MAX(id) FROM syncapi_account_data_type" // "SELECT MAX(id) FROM syncapi_account_data_type"
const selectMaxAccountDataIDSQL = "" + const selectMaxAccountDataIDSQL = "" +
"select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 "
"and c._sid = @x2 "
type accountDataStatements struct { type accountDataStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -248,14 +247,14 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
// err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []AccountDataTypeNumberCosmosData var rows []AccountDataTypeNumberCosmosData
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxAccountDataIDStmt, params, &rows) s.selectMaxAccountDataIDStmt, params, &rows)
if err != cosmosdbutil.ErrNoRows && len(rows) == 1 { if err != cosmosdbutil.ErrNoRows && len(rows) == 1 {

View file

@ -69,9 +69,9 @@ func prepareWithFilters(
} }
switch order { switch order {
case FilterOrderAsc: case FilterOrderAsc:
sql += fmt.Sprintf("order by c.%s.event_id asc ", collectionName) sql += fmt.Sprintf("order by c.%s.id asc ", collectionName)
case FilterOrderDesc: case FilterOrderDesc:
sql += fmt.Sprintf("order by c.%s.event_id desc ", collectionName) sql += fmt.Sprintf("order by c.%s.id desc ", collectionName)
} }
// query += fmt.Sprintf(" LIMIT $%d", offset+1) // query += fmt.Sprintf(" LIMIT $%d", offset+1)
return return

View file

@ -83,8 +83,7 @@ const selectInviteEventsInRangeSQL = "" +
// "SELECT MAX(id) FROM syncapi_invite_events" // "SELECT MAX(id) FROM syncapi_invite_events"
const selectMaxInviteIDSQL = "" + const selectMaxInviteIDSQL = "" +
"select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " + "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 "
"and c._sid = @x2 "
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -296,14 +295,14 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
// err = stmt.QueryRowContext(ctx).Scan(&nullableID) // err = stmt.QueryRowContext(ctx).Scan(&nullableID)
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []inviteEventCosmosMaxNumber var rows []inviteEventCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxInviteIDStmt, params, &rows) s.selectMaxInviteIDStmt, params, &rows)
if len(rows) > 0 { if len(rows) > 0 {

View file

@ -115,8 +115,7 @@ const selectEarlyEventsSQL = "" +
// "SELECT MAX(id) FROM syncapi_output_room_events" // "SELECT MAX(id) FROM syncapi_output_room_events"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 "
"and c._sid = @x2 "
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const updateEventJSONSQL = "" + const updateEventJSONSQL = "" +
@ -336,29 +335,25 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
func (s *outputRoomEventsStatements) SelectMaxEventID( func (s *outputRoomEventsStatements) SelectMaxEventID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
// stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
var rows []outputRoomEventCosmosMaxNumber var rows []outputRoomEventCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxEventIDStmt, params, &rows) s.selectMaxEventIDStmt, params, &rows)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID) // err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil { if len(rows) > 0 {
nullableID.Int64 = rows[0].Max id = rows[0].Max
} }
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View file

@ -382,10 +382,11 @@ func (s *peekStatements) SelectMaxPeekID(
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
} }
var rows []peekCosmosMaxNumber var rows []peekCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxPeekIDStmt, params, &rows) s.selectMaxPeekIDStmt, params, &rows)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID) // err = stmt.QueryRowContext(ctx).Scan(&nullableID)

View file

@ -76,8 +76,7 @@ const selectRoomReceipts = "" +
// "SELECT MAX(id) FROM syncapi_receipts" // "SELECT MAX(id) FROM syncapi_receipts"
const selectMaxReceiptIDSQL = "" + const selectMaxReceiptIDSQL = "" +
"select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 "
"and c._sid = @x2 "
type receiptStatements struct { type receiptStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -210,13 +209,13 @@ func (s *receiptStatements) SelectMaxReceiptID(
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []receiptCosmosMaxNumber var rows []receiptCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxReceiptID, params, &rows) s.selectMaxReceiptID, params, &rows)
// stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) // stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)

View file

@ -83,8 +83,7 @@ const deleteSendToDeviceMessagesSQL = "" +
// "SELECT MAX(id) FROM syncapi_send_to_device" // "SELECT MAX(id) FROM syncapi_send_to_device"
const selectMaxSendToDeviceIDSQL = "" + const selectMaxSendToDeviceIDSQL = "" +
"select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " + "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 "
"and c._sid = @x2 "
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *SyncServerDatasource db *SyncServerDatasource
@ -274,13 +273,13 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.getCollectionName(), "@x1": s.getCollectionName(),
"@x2": s.db.cosmosConfig.TenantName,
} }
var rows []SendToDeviceCosmosMaxNumber var rows []SendToDeviceCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(),
s.selectMaxSendToDeviceIDStmt, params, &rows) s.selectMaxSendToDeviceIDStmt, params, &rows)
// stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) // stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)

View file

@ -76,8 +76,8 @@ type keyBackupVersionCosmosNumber struct {
// "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" // "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
const selectLatestVersionSQL = "" + const selectLatestVersionSQL = "" +
"select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._sid = @x1 and c._cn = @x2 " + "select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._cn = @x1 " +
"and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x3 " "and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x2 "
type keyBackupVersionStatements struct { type keyBackupVersionStatements struct {
db *Database db *Database
@ -276,17 +276,17 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
if version == "" { if version == "" {
// var v *int64 // allows nulls // var v *int64 // allows nulls
params := map[string]interface{}{ params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName, "@x1": s.getCollectionName(),
"@x2": s.getCollectionName(), "@x2": userID,
"@x3": userID,
} }
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
var rows []keyBackupVersionCosmosNumber var rows []keyBackupVersionCosmosNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx, err = cosmosdbapi.PerformQuery(ctx,
s.db.connection, s.db.connection,
s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName, s.db.cosmosConfig.ContainerName,
s.getPartitionKey(userID),
s.selectLatestVersionStmt, params, &rows) s.selectLatestVersionStmt, params, &rows)
if err != nil { if err != nil {
@ -303,6 +303,11 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
return return
} }
versionInt = rows[0].Number versionInt = rows[0].Number
if versionInt == 0 {
err = cosmosdbutil.ErrNoRows
return
}
} else { } else {
if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil {
return return