diff --git a/appservice/storage/cosmosdb/appservice_events_table.go b/appservice/storage/cosmosdb/appservice_events_table.go index efc23db7c..f7cc63e79 100644 --- a/appservice/storage/cosmosdb/appservice_events_table.go +++ b/appservice/storage/cosmosdb/appservice_events_table.go @@ -45,20 +45,20 @@ import ( // CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id); // ` -type EventCosmos struct { +type eventCosmos struct { ID int64 `json:"id"` AppServiceID string `json:"as_id"` HeaderedEventJSON []byte `json:"headered_event_json"` TXNID int64 `json:"txn_id"` } -type EventNumberCosmosData struct { +type eventNumberCosmosData struct { Number int `json:"number"` } -type EventCosmosData struct { +type eventCosmosData struct { cosmosdbapi.CosmosDocument - Event EventCosmos `json:"mx_appservice_event"` + Event eventCosmos `json:"mx_appservice_event"` } // "SELECT id, headered_event_json, txn_id " + @@ -119,50 +119,16 @@ func (s *eventsStatements) prepare(db *Database, writer sqlutil.Writer) (err err return } -func queryEvent(s *eventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *eventsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryEventEventNumber(s *eventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventNumberCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventNumberCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *eventsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) (*EventCosmosData, error) { - response := EventCosmosData{} +func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) (*eventCosmosData, error) { + response := eventCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -178,7 +144,7 @@ func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) return &response, err } -func setEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) (*EventCosmosData, error) { +func setEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -190,7 +156,7 @@ func setEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) ( return &event, ex } -func deleteEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) error { +func deleteEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(event.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -223,13 +189,17 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( // "SELECT id, headered_event_json, txn_id " + // "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": applicationServiceID, } - eventRows, err := queryEvent(s, ctx, s.selectEventsByApplicationServiceIDStmt, params) + var rows []eventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventsByApplicationServiceIDStmt, params, &rows) if err != nil { log.WithFields(log.Fields{ @@ -237,7 +207,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID( }).WithError(err).Fatalf("appservice unable to select new events to send") } - events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit) + events, maxID, txnID, eventsRemaining, err = retrieveEvents(rows, limit) if err != nil { return } @@ -252,7 +222,7 @@ func checkNamedErr(fn func() error, err *error) { } } -func retrieveEvents(eventRows []EventCosmosData, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { +func retrieveEvents(eventRows []eventCosmosData, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) { // Get current time for use in calculating event age nowMilli := time.Now().UnixNano() / int64(time.Millisecond) @@ -318,18 +288,22 @@ func (s *eventsStatements) countEventsByApplicationServiceID( // "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": appServiceID, } - response, err := queryEventEventNumber(s, ctx, s.countEventsByApplicationServiceIDStmt, params) + var rows []eventNumberCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventsByApplicationServiceIDStmt, params, &rows) if err != nil && err != sql.ErrNoRows { return 0, err } - count = response[0].Number + count = rows[0].Number return count, nil } @@ -350,12 +324,10 @@ func (s *eventsStatements) insertEvent( // "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " + // "VALUES ($1, $2, $3)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := fmt.Sprintf("%s", appServiceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, err := getEvent(s, ctx, pk, cosmosDocId) + dbData, err := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.Event.HeaderedEventJSON = eventJSON @@ -369,15 +341,15 @@ func (s *eventsStatements) insertEvent( // appServiceID, // eventJSON, // -1, // No transaction ID yet - data := EventCosmos{ + data := eventCosmos{ AppServiceID: appServiceID, HeaderedEventJSON: eventJSON, ID: idSeq, TXNID: -1, } - dbData = &EventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &eventCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Event: data, } @@ -401,19 +373,23 @@ func (s *eventsStatements) updateTxnIDForEvents( ) (err error) { // "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": appserviceID, "@x3": maxID, } - response, err := queryEvent(s, ctx, s.updateTxnIDForEventsStmt, params) + var rows []eventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.updateTxnIDForEventsStmt, params, &rows) if err != nil { return err } - for _, item := range response { + for _, item := range rows { item.Event.TXNID = int64(txnID) // _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) _, err = setEvent(s, ctx, item) @@ -430,19 +406,24 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID( ) (err error) { // "DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": appserviceID, "@x3": eventTableID, } - response, err := queryEvent(s, ctx, s.deleteEventsBeforeAndIncludingIDStmt, params) + var rows []eventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteEventsBeforeAndIncludingIDStmt, params, &rows) + if err != nil { return err } - for _, item := range response { + for _, item := range rows { // _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) err = deleteEvent(s, ctx, item) } diff --git a/federationsender/storage/cosmosdb/blacklist_table.go b/federationsender/storage/cosmosdb/blacklist_table.go index 531ffddee..64464e2d2 100644 --- a/federationsender/storage/cosmosdb/blacklist_table.go +++ b/federationsender/storage/cosmosdb/blacklist_table.go @@ -32,13 +32,13 @@ import ( // ); // ` -type BlacklistCosmos struct { +type blacklistCosmos struct { ServerName string `json:"server_name"` } -type BlacklistCosmosData struct { +type blacklistCosmosData struct { cosmosdbapi.CosmosDocument - Blacklist BlacklistCosmos `json:"mx_federationsender_blacklist"` + Blacklist blacklistCosmos `json:"mx_federationsender_blacklist"` } // const insertBlacklistSQL = "" + @@ -64,8 +64,16 @@ type blacklistStatements struct { tableName string } -func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) { - response := BlacklistCosmosData{} +func (s *blacklistStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *blacklistStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*blacklistCosmosData, error) { + response := blacklistCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -81,28 +89,7 @@ func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId return &response, err } -func queryBlacklist(s *blacklistStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BlacklistCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []BlacklistCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error { +func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData blacklistCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -137,22 +124,20 @@ func (s *blacklistStatements) InsertBlacklist( // stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (server_name) docId := fmt.Sprintf("%s", serverName) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getBlacklist(s, ctx, pk, cosmosDocId) + dbData, _ := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() } else { - data := BlacklistCosmos{ + data := blacklistCosmos{ ServerName: string(serverName), } - dbData = &BlacklistCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &blacklistCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Blacklist: data, } } @@ -177,13 +162,11 @@ func (s *blacklistStatements) SelectBlacklist( // stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (server_name) docId := fmt.Sprintf("%s", serverName) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // res, err := stmt.QueryContext(ctx, serverName) - res, err := getBlacklist(s, ctx, pk, cosmosDocId) + res, err := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return false, err } @@ -201,13 +184,11 @@ func (s *blacklistStatements) DeleteBlacklist( // "DELETE FROM federationsender_blacklist WHERE server_name = $1" // stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (server_name) docId := fmt.Sprintf("%s", serverName) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err := stmt.ExecContext(ctx, serverName) - res, err := getBlacklist(s, ctx, pk, cosmosDocId) + res, err := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId) if res != nil { _ = deleteBlacklist(s, ctx, *res) } @@ -220,13 +201,17 @@ func (s *blacklistStatements) DeleteAllBlacklist( // "DELETE FROM federationsender_blacklist" // stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } // rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) - rows, err := queryBlacklist(s, ctx, s.deleteAllBlacklistStmt, params) + var rows []blacklistCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteAllBlacklistStmt, params, &rows) if err != nil { return err diff --git a/federationsender/storage/cosmosdb/inbound_peeks_table.go b/federationsender/storage/cosmosdb/inbound_peeks_table.go index 2d139a46a..48ad7af23 100644 --- a/federationsender/storage/cosmosdb/inbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/inbound_peeks_table.go @@ -38,7 +38,7 @@ import ( // ); // ` -type InboundPeekCosmos struct { +type inboundPeekCosmos struct { RoomID string `json:"room_id"` ServerName string `json:"server_name"` PeekID string `json:"peek_id"` @@ -47,9 +47,9 @@ type InboundPeekCosmos struct { RenewalInterval int64 `json:"renewal_interval"` } -type InboundPeekCosmosData struct { +type inboundPeekCosmosData struct { cosmosdbapi.CosmosDocument - InboundPeek InboundPeekCosmos `json:"mx_federationsender_inbound_peek"` + InboundPeek inboundPeekCosmos `json:"mx_federationsender_inbound_peek"` } // const insertInboundPeekSQL = "" + @@ -88,29 +88,16 @@ type inboundPeeksStatements struct { tableName string } -func queryInboundPeek(s *inboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InboundPeekCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []InboundPeekCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *inboundPeeksStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*InboundPeekCosmosData, error) { - response := InboundPeekCosmosData{} +func (s *inboundPeeksStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) { + response := inboundPeekCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -126,7 +113,7 @@ func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, d return &response, err } -func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek InboundPeekCosmosData) (*InboundPeekCosmosData, error) { +func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek inboundPeekCosmosData) (*inboundPeekCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -138,7 +125,7 @@ func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek return &inboundPeek, ex } -func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData InboundPeekCosmosData) error { +func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData inboundPeekCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -172,19 +159,17 @@ func (s *inboundPeeksStatements) InsertInboundPeek( nowMilli := time.Now().UnixNano() / int64(time.Millisecond) // stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getInboundPeek(s, ctx, pk, cosmosDocId) + dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.InboundPeek.RenewedTimestamp = nowMilli dbData.InboundPeek.RenewalInterval = renewalInterval } else { - data := InboundPeekCosmos{ + data := inboundPeekCosmos{ RoomID: roomID, ServerName: string(serverName), PeekID: peekID, @@ -193,8 +178,8 @@ func (s *inboundPeeksStatements) InsertInboundPeek( RenewalInterval: renewalInterval, } - dbData = &InboundPeekCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &inboundPeekCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), InboundPeek: data, } } @@ -218,14 +203,12 @@ func (s *inboundPeeksStatements) RenewInboundPeek( // "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getInboundPeek(s, ctx, pk, cosmosDocId) + res, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return @@ -248,14 +231,12 @@ func (s *inboundPeeksStatements) SelectInboundPeek( ) (*types.InboundPeek, error) { // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(), docId) // row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getInboundPeek(s, ctx, pk, cosmosDocId) + row, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if row == nil { return nil, nil @@ -278,14 +259,18 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( ) (inboundPeeks []types.InboundPeek, err error) { // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) - rows, err := queryInboundPeek(s, ctx, s.selectInboundPeeksStmt, params) + var rows []inboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectInboundPeeksStmt, params, &rows) if err != nil { return @@ -310,15 +295,19 @@ func (s *inboundPeeksStatements) DeleteInboundPeek( ) (err error) { // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": serverName, } // _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) - rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params) + var rows []inboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) if err != nil { return @@ -339,14 +328,18 @@ func (s *inboundPeeksStatements) DeleteInboundPeeks( ) (err error) { // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID) - rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params) + var rows []inboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows) if err != nil { return diff --git a/federationsender/storage/cosmosdb/joined_hosts_table.go b/federationsender/storage/cosmosdb/joined_hosts_table.go index ff8bec7b7..772bf5814 100644 --- a/federationsender/storage/cosmosdb/joined_hosts_table.go +++ b/federationsender/storage/cosmosdb/joined_hosts_table.go @@ -45,15 +45,15 @@ import ( // ON federationsender_joined_hosts (room_id) // ` -type JoinedHostCosmos struct { +type joinedHostCosmos struct { RoomID string `json:"room_id"` EventID string `json:"event_id"` ServerName string `json:"server_name"` } -type JoinedHostCosmosData struct { +type joinedHostCosmosData struct { cosmosdbapi.CosmosDocument - JoinedHost JoinedHostCosmos `json:"mx_federationsender_joined_host"` + JoinedHost joinedHostCosmos `json:"mx_federationsender_joined_host"` } // const insertJoinedHostsSQL = "" + @@ -96,8 +96,16 @@ type joinedHostsStatements struct { tableName string } -func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, docId string) (*JoinedHostCosmosData, error) { - response := JoinedHostCosmosData{} +func (s *joinedHostsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *joinedHostsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, docId string) (*joinedHostCosmosData, error) { + response := joinedHostCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -113,49 +121,7 @@ func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, doc return &response, err } -func queryJoinedHostDistinct(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmos, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []JoinedHostCosmos - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func queryJoinedHost(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []JoinedHostCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData JoinedHostCosmosData) error { +func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData joinedHostCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -194,23 +160,21 @@ func (s *joinedHostsStatements) InsertJoinedHosts( // stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx // ON federationsender_joined_hosts (event_id); docId := fmt.Sprintf("%s", eventID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getJoinedHost(s, ctx, pk, cosmosDocId) + dbData, _ := getJoinedHost(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData == nil { - data := JoinedHostCosmos{ + data := joinedHostCosmos{ EventID: eventID, RoomID: roomID, ServerName: string(serverName), } - dbData = &JoinedHostCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &joinedHostCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), JoinedHost: data, } // _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) @@ -231,14 +195,17 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( for _, eventID := range eventIDs { // "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventID, } // stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) - - rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + var rows []joinedHostCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteJoinedHostsStmt, params, &rows) for _, item := range rows { if err = deleteJoinedHost(s, ctx, item); err != nil { @@ -254,14 +221,18 @@ func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( ) error { // "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) - rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + var rows []joinedHostCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteJoinedHostsStmt, params, &rows) // _, err := stmt.ExecContext(ctx, roomID) for _, item := range rows { @@ -278,14 +249,18 @@ func (s *joinedHostsStatements) SelectJoinedHostsWithTx( // "SELECT event_id, server_name FROM federationsender_joined_hosts" + // " WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) - rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + var rows []joinedHostCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectJoinedHostsStmt, params, &rows) if err != nil { return nil, err @@ -305,13 +280,18 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( ) ([]gomatrixserverlib.ServerName, error) { // "SELECT DISTINCT server_name FROM federationsender_joined_hosts" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } // rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) - rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params) + var rows []joinedHostCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAllJoinedHostsStmt, params, &rows) + if err != nil { return nil, err } @@ -337,14 +317,19 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( // "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" // sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomIDs, } // rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) - rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params) + var rows []joinedHostCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAllJoinedHostsStmt, params, &rows) + if err != nil { return nil, err } @@ -359,7 +344,7 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( return result, nil } -func rowsToJoinedHosts(rows *[]JoinedHostCosmosData) []types.JoinedHost { +func rowsToJoinedHosts(rows *[]joinedHostCosmosData) []types.JoinedHost { var result []types.JoinedHost if rows == nil { return result diff --git a/federationsender/storage/cosmosdb/outbound_peeks_table.go b/federationsender/storage/cosmosdb/outbound_peeks_table.go index ab2efd51d..828ca2222 100644 --- a/federationsender/storage/cosmosdb/outbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/outbound_peeks_table.go @@ -38,7 +38,7 @@ import ( // ); // ` -type OutboundPeekCosmos struct { +type outboundPeekCosmos struct { RoomID string `json:"room_id"` ServerName string `json:"server_name"` PeekID string `json:"peek_id"` @@ -47,9 +47,9 @@ type OutboundPeekCosmos struct { RenewalInterval int64 `json:"renewal_interval"` } -type OutboundPeekCosmosData struct { +type outboundPeekCosmosData struct { cosmosdbapi.CosmosDocument - OutboundPeek OutboundPeekCosmos `json:"mx_federationsender_outbound_peek"` + OutboundPeek outboundPeekCosmos `json:"mx_federationsender_outbound_peek"` } // const insertOutboundPeekSQL = "" + @@ -85,29 +85,16 @@ type outboundPeeksStatements struct { tableName string } -func queryOutboundPeek(s *outboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutboundPeekCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []OutboundPeekCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *outboundPeeksStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*OutboundPeekCosmosData, error) { - response := OutboundPeekCosmosData{} +func (s *outboundPeeksStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) { + response := outboundPeekCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -123,7 +110,7 @@ func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, return &response, err } -func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek OutboundPeekCosmosData) (*OutboundPeekCosmosData, error) { +func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek outboundPeekCosmosData) (*outboundPeekCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -135,7 +122,7 @@ func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPe return &outboundPeek, ex } -func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData OutboundPeekCosmosData) error { +func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData outboundPeekCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -168,20 +155,18 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( // stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getOutboundPeek(s, ctx, pk, cosmosDocId) + dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.OutboundPeek.RenewalInterval = renewalInterval dbData.OutboundPeek.RenewedTimestamp = nowMilli } else { - data := OutboundPeekCosmos{ + data := outboundPeekCosmos{ RoomID: roomID, ServerName: string(serverName), PeekID: peekID, @@ -190,8 +175,8 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( RenewalInterval: renewalInterval, } - dbData = &OutboundPeekCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &outboundPeekCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), OutboundPeek: data, } @@ -215,14 +200,12 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( // "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getOutboundPeek(s, ctx, pk, cosmosDocId) + res, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return @@ -245,14 +228,12 @@ func (s *outboundPeeksStatements) SelectOutboundPeek( // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (room_id, server_name, peek_id) docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getOutboundPeek(s, ctx, pk, cosmosDocId) + row, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return nil, err @@ -280,14 +261,19 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( if err != nil { return } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID) - rows, err := queryOutboundPeek(s, ctx, s.selectOutboundPeeksStmt, params) + var rows []outboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectOutboundPeeksStmt, params, &rows) if err != nil { return @@ -313,15 +299,19 @@ func (s *outboundPeeksStatements) DeleteOutboundPeek( // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": serverName, } // _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) - rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeekStmt, params) + var rows []outboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteOutboundPeekStmt, params, &rows) if err != nil { return @@ -343,14 +333,18 @@ func (s *outboundPeeksStatements) DeleteOutboundPeeks( // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID) - rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeeksStmt, params) + var rows []outboundPeekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteOutboundPeeksStmt, params, &rows) if err != nil { return diff --git a/federationsender/storage/cosmosdb/queue_edus_table.go b/federationsender/storage/cosmosdb/queue_edus_table.go index 790a6dae8..f08ee65bb 100644 --- a/federationsender/storage/cosmosdb/queue_edus_table.go +++ b/federationsender/storage/cosmosdb/queue_edus_table.go @@ -38,19 +38,19 @@ import ( // ON federationsender_queue_edus (json_nid, server_name); // ` -type QueueEDUCosmos struct { +type queueEDUCosmos struct { EDUType string `json:"edu_type"` ServerName string `json:"server_name"` JSONNID int64 `json:"json_nid"` } -type QueueEDUCosmosNumber struct { +type queueEDUCosmosNumber struct { Number int64 `json:"number"` } -type QueueEDUCosmosData struct { +type queueEDUCosmosData struct { cosmosdbapi.CosmosDocument - QueueEDU QueueEDUCosmos `json:"mx_federationsender_queue_edu"` + QueueEDU queueEDUCosmos `json:"mx_federationsender_queue_edu"` } // const insertQueueEDUSQL = "" + @@ -96,70 +96,15 @@ type queueEDUsStatements struct { tableName string } -func queryQueueEDUC(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueueEDUCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *queueEDUsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryQueueEDUCDistinct(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmos, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueueEDUCosmos - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *queueEDUsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func queryQueueEDUCNumber(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueueEDUCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData QueueEDUCosmosData) error { +func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData queueEDUCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -198,21 +143,19 @@ func (s *queueEDUsStatements) InsertQueueEDU( // stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx // ON federationsender_queue_edus (json_nid, server_name); docId := fmt.Sprintf("%d_%s", nid, eduType) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := QueueEDUCosmos{ + data := queueEDUCosmos{ EDUType: eduType, JSONNID: nid, ServerName: string(serverName), } - dbData := &QueueEDUCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &queueEDUCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), QueueEDU: data, } @@ -244,16 +187,20 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( // deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, "@x3": jsonNIDs, } // stmt := sqlutil.TxStmt(txn, deleteStmt) // _, err = stmt.ExecContext(ctx, params...) - rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params) + var rows []queueEDUCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), deleteQueueEDUsSQL, params, &rows) if err != nil { return err @@ -280,15 +227,20 @@ func (s *queueEDUsStatements) SelectQueueEDUs( // " LIMIT $2" // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, "@x3": limit, } // rows, err := stmt.QueryContext(ctx, serverName, limit) - rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params) + var rows []queueEDUCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), deleteQueueEDUsSQL, params, &rows) + if err != nil { return nil, err } @@ -309,14 +261,19 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( // "SELECT COUNT(*) FROM federationsender_queue_edus" + // " WHERE json_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": jsonNID, } // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) - rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUReferenceJSONCountStmt, params) + var rows []queueEDUCosmosNumber + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueEDUReferenceJSONCountStmt, params, &rows) + if len(rows) == 0 { return -1, nil } @@ -333,14 +290,19 @@ func (s *queueEDUsStatements) SelectQueueEDUCount( // "SELECT COUNT(*) FROM federationsender_queue_edus" + // " WHERE server_name = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, } // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUCountStmt, params) + var rows []queueEDUCosmosNumber + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueEDUCountStmt, params, &rows) + if len(rows) == 0 { // It's acceptable for there to be no rows referencing a given // JSON NID but it's not an error condition. Just return as if @@ -358,14 +320,19 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( // "SELECT DISTINCT server_name FROM federationsender_queue_edus" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) // rows, err := stmt.QueryContext(ctx) - rows, err := queryQueueEDUCDistinct(s, ctx, s.selectQueueEDUServerNamesStmt, params) + var rows []queueEDUCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueEDUServerNamesStmt, params, &rows) + if err != nil { return nil, err } diff --git a/federationsender/storage/cosmosdb/queue_json_table.go b/federationsender/storage/cosmosdb/queue_json_table.go index 0c464fae5..9dc5d343e 100644 --- a/federationsender/storage/cosmosdb/queue_json_table.go +++ b/federationsender/storage/cosmosdb/queue_json_table.go @@ -35,14 +35,14 @@ import ( // ); // ` -type QueueJSONCosmos struct { +type queueJSONCosmos struct { JSONNID int64 `json:"json_nid"` JSONBody []byte `json:"json_body"` } -type QueueJSONCosmosData struct { +type queueJSONCosmosData struct { cosmosdbapi.CosmosDocument - QueueJSON QueueJSONCosmos `json:"mx_federationsender_queue_json"` + QueueJSON queueJSONCosmos `json:"mx_federationsender_queue_json"` } // const insertJSONSQL = "" + @@ -68,28 +68,15 @@ type queueJSONStatements struct { tableName string } -func queryQueueJSON(s *queueJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueJSONCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueueJSONCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *queueJSONStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData QueueJSONCosmosData) error { +func (s *queueJSONStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData queueJSONCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -122,22 +109,20 @@ func (s *queueJSONStatements) InsertQueueJSON( // json_nid INTEGER PRIMARY KEY AUTOINCREMENT, idSeq, err := GetNextQueueJSONNID(s, ctx) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // json_nid INTEGER PRIMARY KEY AUTOINCREMENT, docId := fmt.Sprintf("%d", idSeq) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) //Convert to byte jsonData := []byte(json) - data := QueueJSONCosmos{ + data := queueJSONCosmos{ JSONNID: idSeq, JSONBody: jsonData, } - dbData := &QueueJSONCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &queueJSONCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), QueueJSON: data, } @@ -165,16 +150,20 @@ func (s *queueJSONStatements) DeleteQueueJSON( // "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": nids, } // deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) // deleteStmt, err := txn.Prepare(deleteSQL) // stmt := sqlutil.TxStmt(txn, deleteStmt) - rows, err := queryQueueJSON(s, ctx, deleteJSONSQL, params) + var rows []queueJSONCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), deleteJSONSQL, params, &rows) if err != nil { return err @@ -198,15 +187,19 @@ func (s *queueJSONStatements) SelectQueueJSON( // "SELECT json_nid, json_body FROM federationsender_queue_json" + // " WHERE json_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": jsonNIDs, } // selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) // selectStmt, err := txn.Prepare(selectSQL) - rows, err := queryQueueJSON(s, ctx, selectJSONSQL, params) + var rows []queueJSONCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectJSONSQL, params, &rows) if err != nil { return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) diff --git a/federationsender/storage/cosmosdb/queue_pdus_table.go b/federationsender/storage/cosmosdb/queue_pdus_table.go index a4afa5dce..99e679c6e 100644 --- a/federationsender/storage/cosmosdb/queue_pdus_table.go +++ b/federationsender/storage/cosmosdb/queue_pdus_table.go @@ -39,19 +39,19 @@ import ( // ON federationsender_queue_pdus (json_nid, server_name); // ` -type QueuePDUCosmos struct { +type queuePDUCosmos struct { TransactionID string `json:"transaction_id"` ServerName string `json:"server_name"` JSONNID int64 `json:"json_nid"` } -type QueuePDUCosmosNumber struct { +type queuePDUCosmosNumber struct { Number int64 `json:"number"` } -type QueuePDUCosmosData struct { +type queuePDUCosmosData struct { cosmosdbapi.CosmosDocument - QueuePDU QueuePDUCosmos `json:"mx_federationsender_queue_pdu"` + QueuePDU queuePDUCosmos `json:"mx_federationsender_queue_pdu"` } // const insertQueuePDUSQL = "" + @@ -108,70 +108,15 @@ type queuePDUsStatements struct { tableName string } -func queryQueuePDU(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueuePDUCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *queuePDUsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryQueuePDUDistinct(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmos, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueuePDUCosmos - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *queuePDUsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func queryQueuePDUNumber(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []QueuePDUCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData QueuePDUCosmosData) error { +func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData queuePDUCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -210,21 +155,19 @@ func (s *queuePDUsStatements) InsertQueuePDU( // "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + // " VALUES ($1, $2, $3)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx // ON federationsender_queue_pdus (json_nid, server_name); docId := fmt.Sprintf("%d_%s", nid, serverName) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := QueuePDUCosmos{ + data := queuePDUCosmos{ JSONNID: nid, ServerName: string(serverName), TransactionID: string(transactionID), } - dbData := &QueuePDUCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &queuePDUCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), QueuePDU: data, } @@ -255,16 +198,20 @@ func (s *queuePDUsStatements) DeleteQueuePDUs( // "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, "@x3": jsonNIDs, } // deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) // deleteStmt, err := txn.Prepare(deleteSQL) - rows, err := queryQueuePDU(s, ctx, deleteQueuePDUsSQL, params) + var rows []queuePDUCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), deleteQueuePDUsSQL, params, &rows) if err != nil { return err @@ -290,14 +237,18 @@ func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( // " ORDER BY transaction_id ASC" + // " LIMIT 1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, } // stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) - rows, err := queryQueuePDU(s, ctx, s.selectQueueNextTransactionIDStmt, params) + var rows []queuePDUCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueNextTransactionIDStmt, params, &rows) if err != nil { return "", err @@ -319,14 +270,18 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( // "SELECT COUNT(*) FROM federationsender_queue_pdus" + // " WHERE json_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": jsonNID, } // stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) - rows, err := queryQueuePDUNumber(s, ctx, s.selectQueueReferenceJSONCountStmt, params) + var rows []queuePDUCosmosNumber + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueReferenceJSONCountStmt, params, &rows) if err != nil { return -1, err @@ -348,14 +303,18 @@ func (s *queuePDUsStatements) SelectQueuePDUCount( // "SELECT COUNT(*) FROM federationsender_queue_pdus" + // " WHERE server_name = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, } // stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - rows, err := queryQueuePDUNumber(s, ctx, s.selectQueuePDUsCountStmt, params) + var rows []queuePDUCosmosNumber + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueuePDUsCountStmt, params, &rows) if err != nil { return 0, err @@ -382,16 +341,20 @@ func (s *queuePDUsStatements) SelectQueuePDUs( // " WHERE server_name = $1" + // " LIMIT $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": serverName, "@x3": limit, } // stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) // rows, err := stmt.QueryContext(ctx, serverName, limit) - rows, err := queryQueuePDU(s, ctx, s.selectQueuePDUsStmt, params) + var rows []queuePDUCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueuePDUsStmt, params, &rows) if err != nil { return nil, err @@ -412,14 +375,19 @@ func (s *queuePDUsStatements) SelectQueuePDUServerNames( // "SELECT DISTINCT server_name FROM federationsender_queue_pdus" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } // stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) // rows, err := stmt.QueryContext(ctx) - rows, err := queryQueuePDUDistinct(s, ctx, s.selectQueueServerNamesStmt, params) + var rows []queuePDUCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectQueueServerNamesStmt, params, &rows) + if err != nil { return nil, err } diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go index 1dfddb94e..f7e74001c 100644 --- a/internal/cosmosdbapi/client.go +++ b/internal/cosmosdbapi/client.go @@ -47,6 +47,51 @@ func (doc *CosmosDocument) SetUpdateTime() { doc.Ut = now.Format(time.RFC3339) } +func PerformQuery(ctx context.Context, + conn CosmosConnection, + databaseName string, + containerName string, + partitonKey string, + qryString string, + params map[string]interface{}, + response interface{}) error { + optionsQry := GetQueryDocumentsOptions(partitonKey) + var query = GetQuery(qryString, params) + _, err := GetClient(conn).QueryDocuments( + ctx, + databaseName, + containerName, + query, + &response, + optionsQry) + return err +} + +func PerformQueryAllPartitions(ctx context.Context, + conn CosmosConnection, + databaseName string, + containerName string, + qryString string, + params map[string]interface{}, + response interface{}) error { + var optionsQry = GetQueryAllPartitionsDocumentsOptions() + var query = GetQuery(qryString, params) + _, err := GetClient(conn).QueryDocuments( + ctx, + databaseName, + containerName, + query, + &response, + optionsQry) + + // When there are no Rows we seem to get the generic Bad Req JSON error + if err != nil { + // return nil, err + } + + return nil +} + func GenerateDocument( collection string, tenantName string, diff --git a/internal/cosmosdbapi/document.go b/internal/cosmosdbapi/document.go index 64c68de01..879d90cfa 100644 --- a/internal/cosmosdbapi/document.go +++ b/internal/cosmosdbapi/document.go @@ -32,6 +32,10 @@ func GetDocumentId(tenantName string, collectionName string, id string) string { return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, safeId) } -func GetPartitionKey(tenantName string, collectionName string) string { +func GetPartitionKeyByCollection(tenantName string, collectionName string) string { return fmt.Sprintf("%s,%s", collectionName, tenantName) } + +func GetPartitionKeyByUniqueId(tenantName string, collectionName string, uniqueId string) string { + return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, uniqueId) +} diff --git a/internal/cosmosdbutil/partition_offset_table.go b/internal/cosmosdbutil/partition_offset_table.go index 39f296f75..c3c2fe991 100644 --- a/internal/cosmosdbutil/partition_offset_table.go +++ b/internal/cosmosdbutil/partition_offset_table.go @@ -46,15 +46,15 @@ import ( // ); // ` -type PartitionOffsetCosmos struct { +type partitionOffsetCosmos struct { Topic string `json:"topic"` Partition int32 `json:"partition"` PartitionOffset int64 `json:"partition_offset"` } -type PartitionOffsetCosmosData struct { +type partitionOffsetCosmosData struct { cosmosdbapi.CosmosDocument - PartitionOffset PartitionOffsetCosmos `json:"mx_partition_offset"` + PartitionOffset partitionOffsetCosmos `json:"mx_partition_offset"` } // "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1" @@ -84,8 +84,18 @@ type PartitionOffsetStatements struct { tableName string } -func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*PartitionOffsetCosmosData, error) { - response := PartitionOffsetCosmosData{} +func (s PartitionOffsetStatements) getCollectionName() string { + // Include the Prefix + tableName := fmt.Sprintf("%s_%s", s.prefix, s.tableName) + return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName) +} + +func (s *PartitionOffsetStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.CosmosConfig.TenantName, s.getCollectionName()) +} + +func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) { + response := partitionOffsetCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.Connection, s.db.CosmosConfig, @@ -101,27 +111,6 @@ func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk st return &response, err } -func queryPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PartitionOffsetCosmosData, error) { - var dbCollectionName = getCollectionName(*s) - var pk = cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName) - var response []PartitionOffsetCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.Connection).QueryDocuments( - ctx, - s.db.CosmosConfig.DatabaseName, - s.db.CosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - // Prepare converts the raw SQL statements into prepared statements. // Takes a prefix to prepend to the table name used to store the partition offsets. // This allows multiple components to share the same database schema. @@ -155,13 +144,18 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( // "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1" - var dbCollectionName = getCollectionName(*s) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": topic, } - rows, err := queryPartitionOffset(s, ctx, s.selectPartitionOffsetsStmt, params) + var rows []partitionOffsetCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.Connection, + s.db.CosmosConfig.DatabaseName, + s.db.CosmosConfig.ContainerName, + s.getPartitionKey(), s.selectPartitionOffsetsStmt, params, &rows) + // rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) if err != nil { return nil, err @@ -197,25 +191,23 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset( // stmt := TxStmt(txn, s.upsertPartitionOffsetStmt) - dbCollectionName := getCollectionName(*s) // UNIQUE (topic, partition) docId := fmt.Sprintf("%s_%d", topic, partition) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getPartitionOffset(s, ctx, pk, cosmosDocId) + dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.PartitionOffset.PartitionOffset = offset } else { - data := PartitionOffsetCosmos{ + data := partitionOffsetCosmos{ Partition: partition, PartitionOffset: offset, Topic: topic, } - dbData = &PartitionOffsetCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.CosmosConfig.TenantName, pk, cosmosDocId), + dbData = &partitionOffsetCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), PartitionOffset: data, } @@ -231,9 +223,3 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset( &dbData) }) } - -func getCollectionName(s PartitionOffsetStatements) string { - // Include the Prefix - tableName := fmt.Sprintf("%s_%s", s.prefix, s.tableName) - return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName) -} diff --git a/internal/naffka/naffkacosmosdb/naffka_topics_table.go b/internal/naffka/naffkacosmosdb/naffka_topics_table.go index 8bf4d18a7..67d912364 100644 --- a/internal/naffka/naffkacosmosdb/naffka_topics_table.go +++ b/internal/naffka/naffkacosmosdb/naffka_topics_table.go @@ -28,21 +28,21 @@ import ( // ); // ` -type TopicCosmos struct { +type topicCosmos struct { TopicName string `json:"topic_name"` TopicNID int64 `json:"topic_nid"` } -type TopicCosmosNumber struct { +type topicCosmosNumber struct { Number int64 `json:"number"` } -type TopicCosmosData struct { +type topicCosmosData struct { cosmosdbapi.CosmosDocument - Topic TopicCosmos `json:"mx_naffka_topic"` + Topic topicCosmos `json:"mx_naffka_topic"` } -type MessageCosmos struct { +type messageCosmos struct { TopicNID int64 `json:"topic_nid"` MessageOffset int64 `json:"message_offset"` MessageKey []byte `json:"message_key"` @@ -50,9 +50,9 @@ type MessageCosmos struct { MessageTimestampNS int64 `json:"message_timestamp_ns"` } -type MessageCosmosData struct { +type messageCosmosData struct { cosmosdbapi.CosmosDocument - Message MessageCosmos `json:"mx_naffka_message"` + Message messageCosmos `json:"mx_naffka_message"` } // const insertTopicSQL = "" + @@ -104,71 +104,24 @@ type topicsStatements struct { tableNameMessages string } -func queryTopic(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]TopicCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics) - var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName) - var response []TopicCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments( - ctx, - s.DB.cosmosConfig.DatabaseName, - s.DB.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *topicsStatements) getCollectionNameTopics() string { + return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics) } -func queryTopicNumber(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]TopicCosmosNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics) - var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName) - var response []TopicCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments( - ctx, - s.DB.cosmosConfig.DatabaseName, - s.DB.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *topicsStatements) getPartitionKeyTopics() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameTopics()) } -func queryMessage(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MessageCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages) - var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName) - var response []MessageCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments( - ctx, - s.DB.cosmosConfig.DatabaseName, - s.DB.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *topicsStatements) getCollectionNameMessages() string { + return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages) } -func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*TopicCosmosData, error) { - response := TopicCosmosData{} +func (s *topicsStatements) getPartitionKeyMessages() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages()) +} + +func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) { + response := topicCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.DB.connection, s.DB.cosmosConfig, @@ -212,24 +165,22 @@ func (t *topicsStatements) InsertTopic( // return errSeq // } - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics) // topic_name TEXT UNIQUE, docId := fmt.Sprintf("%s", topicName) - cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameTopics(), docId) - dbData, _ := getTopic(t, ctx, pk, cosmosDocId) + dbData, _ := getTopic(t, ctx, t.getPartitionKeyTopics(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.Topic.TopicName = topicName } else { - data := TopicCosmos{ + data := topicCosmos{ TopicNID: topicNID, TopicName: topicName, } - dbData = &TopicCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, t.DB.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &topicCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameTopics(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyTopics(), cosmosDocId), Topic: data, } } @@ -250,14 +201,18 @@ func (t *topicsStatements) SelectNextTopicNID( // "SELECT COUNT(topic_nid)+1 AS topic_nid FROM naffka_topics" - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": t.getCollectionNameTopics(), } // stmt := sqlutil.TxStmt(txn, t.selectNextTopicNIDStmt) // err = stmt.QueryRowContext(ctx).Scan(&topicNID) - rows, err := queryTopicNumber(t, ctx, t.selectNextTopicNIDStmt, params) + var rows []topicCosmosNumber + err = cosmosdbapi.PerformQuery(ctx, + t.DB.connection, + t.DB.cosmosConfig.DatabaseName, + t.DB.cosmosConfig.ContainerName, + t.getPartitionKeyTopics(), t.selectNextTopicNIDStmt, params, &rows) if err != nil { return 0, err @@ -279,14 +234,12 @@ func (t *topicsStatements) SelectTopic( // stmt := sqlutil.TxStmt(txn, t.selectTopicStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics) // topic_name TEXT UNIQUE, docId := fmt.Sprintf("%s", topicName) - cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameTopics(), docId) // err = stmt.QueryRowContext(ctx, topicName).Scan(&topicNID) - res, err := getTopic(t, ctx, pk, cosmosDocId) + res, err := getTopic(t, ctx, t.getPartitionKeyTopics(), cosmosDocId) if err != nil { return 0, err @@ -304,14 +257,18 @@ func (t *topicsStatements) SelectTopics( // "SELECT topic_name, topic_nid FROM naffka_topics" - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": t.getCollectionNameTopics(), } // stmt := sqlutil.TxStmt(txn, t.selectTopicsStmt) // rows, err := stmt.QueryContext(ctx) - rows, err := queryTopic(t, ctx, t.selectTopicsStmt, params) + var rows []topicCosmosData + err := cosmosdbapi.PerformQuery(ctx, + t.DB.connection, + t.DB.cosmosConfig.DatabaseName, + t.DB.cosmosConfig.ContainerName, + t.getPartitionKeyTopics(), t.selectTopicsStmt, params, &rows) if err != nil { return nil, err @@ -340,13 +297,11 @@ func (t *topicsStatements) InsertTopics( // stmt := sqlutil.TxStmt(txn, t.insertTopicsStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages) // UNIQUE (topic_nid, message_offset) docId := fmt.Sprintf("%d_%d", topicNID, messageOffset) - cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameMessages(), docId) - data := MessageCosmos{ + data := messageCosmos{ TopicNID: topicNID, MessageOffset: messageOffset, MessageKey: topicKey, @@ -354,8 +309,8 @@ func (t *topicsStatements) InsertTopics( MessageTimestampNS: messageTimestampNs, } - dbData := &MessageCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, t.DB.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &messageCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(), cosmosDocId), Message: data, } @@ -379,9 +334,8 @@ func (t *topicsStatements) SelectMessages( // " FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" + // " ORDER BY message_offset ASC" - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": t.getCollectionNameMessages(), "@x2": topicNID, "@x3": startOffset, "@x4": endOffset, @@ -389,7 +343,12 @@ func (t *topicsStatements) SelectMessages( // stmt := sqlutil.TxStmt(txn, t.selectMessagesStmt) // rows, err := stmt.QueryContext(ctx, topicNID, startOffset, endOffset) - rows, err := queryMessage(t, ctx, t.selectMessagesStmt, params) + var rows []messageCosmosData + err := cosmosdbapi.PerformQuery(ctx, + t.DB.connection, + t.DB.cosmosConfig.DatabaseName, + t.DB.cosmosConfig.ContainerName, + t.getPartitionKeyMessages(), t.selectMessagesStmt, params, &rows) if err != nil { return nil, err @@ -416,15 +375,19 @@ func (t *topicsStatements) SelectMaxOffset( // "SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" + // " ORDER BY message_offset DESC LIMIT 1" - var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": t.getCollectionNameMessages(), "@x2": topicNID, } // stmt := sqlutil.TxStmt(txn, t.selectMaxOffsetStmt) // err = stmt.QueryRowContext(ctx, topicNID).Scan(&offset) - rows, err := queryMessage(t, ctx, t.selectMaxOffsetStmt, params) + var rows []messageCosmosData + err = cosmosdbapi.PerformQuery(ctx, + t.DB.connection, + t.DB.cosmosConfig.DatabaseName, + t.DB.cosmosConfig.ContainerName, + t.getPartitionKeyMessages(), t.selectMaxOffsetStmt, params, &rows) if err != nil { return 0, err diff --git a/keyserver/storage/cosmosdb/cross_signing_keys_table.go b/keyserver/storage/cosmosdb/cross_signing_keys_table.go index 99ce6af05..735d062ff 100644 --- a/keyserver/storage/cosmosdb/cross_signing_keys_table.go +++ b/keyserver/storage/cosmosdb/cross_signing_keys_table.go @@ -34,15 +34,15 @@ import ( // ); // ` -type CrossSigningKeysCosmos struct { +type crossSigningKeysCosmos struct { UserID string `json:"user_id"` KeyType int64 `json:"key_type"` KeyData []byte `json:"key_data"` } -type CrossSigningKeysCosmosData struct { +type crossSigningKeysCosmosData struct { cosmosdbapi.CosmosDocument - CrossSigningKeys CrossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"` + CrossSigningKeys crossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"` } // "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + @@ -62,8 +62,16 @@ type crossSigningKeysStatements struct { tableName string } -func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*CrossSigningKeysCosmosData, error) { - response := CrossSigningKeysCosmosData{} +func (s *crossSigningKeysStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *crossSigningKeysStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) { + response := crossSigningKeysCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -79,27 +87,6 @@ func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk return &response, err } -func queryCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningKeysCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []CrossSigningKeysCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func NewSqliteCrossSigningKeysTable(db *Database) (tables.CrossSigningKeys, error) { s := &crossSigningKeysStatements{ db: db, @@ -115,12 +102,18 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( ) (r types.CrossSigningKeyMap, err error) { // "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + // " WHERE user_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, } - rows, err := queryCrossSigningKeys(s, ctx, s.selectCrossSigningKeysForUserStmt, params) + + var rows []crossSigningKeysCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectCrossSigningKeysForUserStmt, params, &rows) + // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) if err != nil { return nil, err @@ -154,25 +147,23 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( if !ok { return fmt.Errorf("unknown key purpose %q", keyType) } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // PRIMARY KEY (user_id, key_type) docId := fmt.Sprintf("%s_%s", userID, keyType) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getCrossSigningKeys(s, ctx, pk, cosmosDocId) + dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.CrossSigningKeys.KeyData = keyData } else { - data := CrossSigningKeysCosmos{ + data := crossSigningKeysCosmos{ UserID: userID, KeyType: int64(keyTypeInt), KeyData: keyData, } - dbData = &CrossSigningKeysCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &crossSigningKeysCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CrossSigningKeys: data, } } diff --git a/keyserver/storage/cosmosdb/cross_signing_sigs_table.go b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go index b6c018185..5e1cc0e3e 100644 --- a/keyserver/storage/cosmosdb/cross_signing_sigs_table.go +++ b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go @@ -36,7 +36,7 @@ import ( // ); // ` -type CrossSigningSigsCosmos struct { +type crossSigningSigsCosmos struct { OriginUserId string `json:"origin_user_id"` OriginKeyId string `json:"origin_key_id"` TargetUserId string `json:"target_user_id"` @@ -44,9 +44,9 @@ type CrossSigningSigsCosmos struct { Signature []byte `json:"signature"` } -type CrossSigningSigsCosmosData struct { +type crossSigningSigsCosmosData struct { cosmosdbapi.CosmosDocument - CrossSigningSigs CrossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"` + CrossSigningSigs crossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"` } // "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + @@ -74,8 +74,16 @@ type crossSigningSigsStatements struct { tableName string } -func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*CrossSigningSigsCosmosData, error) { - response := CrossSigningSigsCosmosData{} +func (s *crossSigningSigsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *crossSigningSigsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) { + response := crossSigningSigsCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -91,28 +99,7 @@ func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk return &response, err } -func queryCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningSigsCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []CrossSigningSigsCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData CrossSigningSigsCosmosData) error { +func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData crossSigningSigsCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -147,13 +134,19 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ) (r types.CrossSigningSigMap, err error) { // "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + // " WHERE target_user_id = $1 AND target_key_id = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": targetUserID, "@x3": targetKeyID, } - rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params) + + var rows []crossSigningSigsCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) + // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) if err != nil { return nil, err @@ -187,19 +180,18 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ) error { // "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + // " VALUES($1, $2, $3, $4, $5)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + // PRIMARY KEY (origin_user_id, target_user_id, target_key_id) docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getCrossSigningSigs(s, ctx, pk, cosmosDocId) + dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.CrossSigningSigs.OriginKeyId = string(originKeyID) dbData.CrossSigningSigs.Signature = signature } else { - data := CrossSigningSigsCosmos{ + data := crossSigningSigsCosmos{ TargetUserId: targetUserID, TargetKeyId: string(targetKeyID), OriginUserId: originUserID, @@ -207,8 +199,8 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( Signature: signature, } - dbData = &CrossSigningSigsCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &crossSigningSigsCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CrossSigningSigs: data, } } @@ -228,13 +220,18 @@ func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( targetUserID string, targetKeyID gomatrixserverlib.KeyID, ) error { // "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": targetUserID, "@x3": targetKeyID, } - rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params) + var rows []crossSigningSigsCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows) + // if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { // return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) // } diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go index c67a231e0..cf17c1c6e 100644 --- a/keyserver/storage/cosmosdb/device_keys_table.go +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -39,7 +39,7 @@ import ( // ); // ` -type DeviceKeyCosmos struct { +type deviceKeyCosmos struct { UserID string `json:"user_id"` DeviceID string `json:"device_id"` // Use the CosmosDB.Timestamp for this one @@ -49,13 +49,13 @@ type DeviceKeyCosmos struct { DisplayName string `json:"display_name"` } -type DeviceKeyCosmosNumber struct { +type deviceKeyCosmosNumber struct { Number int64 `json:"number"` } -type DeviceKeyCosmosData struct { +type deviceKeyCosmosData struct { cosmosdbapi.CosmosDocument - DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"` + DeviceKey deviceKeyCosmos `json:"mx_keyserver_device_key"` } // const upsertDeviceKeysSQL = "" + @@ -97,54 +97,8 @@ const deleteDeviceKeysSQL = "" + // const deleteAllDeviceKeysSQL = "" + // "DELETE FROM keyserver_device_keys WHERE user_id=$1" -func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []DeviceKeyCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) { - var response []DeviceKeyCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(qry, params) - var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - //WHen there is no data these GroupBy queries return errors - // if err != nil { - // return nil, err - // } - - if len(response) == 0 { - return nil, cosmosdbutil.ErrNoRows - } - - return response, nil -} - -func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) { - response := DeviceKeyCosmosData{} +func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*deviceKeyCosmosData, error) { + response := deviceKeyCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -160,7 +114,7 @@ func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId return &response, err } -func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error { +func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData deviceKeyCosmosData) error { // "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + // " VALUES ($1, $2, $3, $4, $5, $6)" + // " ON CONFLICT (user_id, device_id)" + @@ -189,8 +143,8 @@ func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De return nil } -func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos { - return DeviceKeyCosmos{ +func mapFromDeviceKeyMessage(key api.DeviceMessage) deviceKeyCosmos { + return deviceKeyCosmos{ DeviceID: key.DeviceID, DisplayName: key.DisplayName, KeyJSON: key.KeyJSON, @@ -210,6 +164,14 @@ type deviceKeysStatements struct { tableName string } +func (s *deviceKeysStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *deviceKeysStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ db: db, @@ -221,7 +183,7 @@ func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { return s, nil } -func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error { +func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData deviceKeyCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -239,19 +201,24 @@ func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { // "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" // _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, } - response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params) + + var rows []deviceKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) if err != nil { return err } - for _, item := range response { + for _, item := range rows { errItem := deleteDeviceKeyCore(s, ctx, item) if errItem != nil { return errItem @@ -265,18 +232,23 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql // "DELETE FROM keyserver_device_keys WHERE user_id=$1" // _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, } - response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params) + + var rows []deviceKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows) if err != nil { return err } - for _, item := range response { + for _, item := range rows { errItem := deleteDeviceKeyCore(s, ctx, item) if errItem != nil { return errItem @@ -293,12 +265,18 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, } - response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params) + + var rows []deviceKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectBatchDeviceKeysStmt, params, &rows) + // rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err @@ -306,7 +284,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID // defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") var result []api.DeviceMessage - for _, item := range response { + for _, item := range rows { dk := api.DeviceMessage{ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{}, @@ -344,13 +322,12 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] // "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" // err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (user_id, device_id) docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getDeviceKey(s, ctx, pk, cosmosDocId) + response, err := getDeviceKey(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil && err != cosmosdbutil.ErrNoRows { return err @@ -377,15 +354,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn // "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": userID, } // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - response, err := queryDeviceKeyNumber(s, ctx, selectMaxStreamForUserSQL, params) + var rows []deviceKeyCosmosNumber + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectMaxStreamForUserSQL, params, &rows) if err != nil { if err == cosmosdbutil.ErrNoRows { @@ -395,8 +376,8 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn } } - if len(response) > 0 { - nullStream.Int32 = int32(response[0].Number) + if len(rows) > 0 { + nullStream.Int32 = int32(rows[0].Number) } if nullStream.Valid { @@ -415,10 +396,9 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID iStreamIDs[i+1] = streamIDs[i] } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": userID, "@x4": iStreamIDs, } @@ -428,7 +408,12 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID // var count sql.NullInt32 // err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) - response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params) + var rows []deviceKeyCosmosNumber + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), countStreamIDsForUserSQL, params, &rows) if err != nil { return 0, err @@ -436,8 +421,8 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID // if count.Valid { // return int(count.Int32), nil // } - if response[0].Number >= 0 { - return int(response[0].Number), nil + if rows[0].Number >= 0 { + return int(rows[0].Number), nil } return 0, nil } @@ -448,16 +433,14 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx // " VALUES ($1, $2, $3, $4, $5, $6)" + // " ON CONFLICT (user_id, device_id)" + // " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) for _, key := range keys { // UNIQUE (user_id, device_id) docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData := &DeviceKeyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &deviceKeyCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), DeviceKey: mapFromDeviceKeyMessage(key), } diff --git a/keyserver/storage/cosmosdb/key_changes_table.go b/keyserver/storage/cosmosdb/key_changes_table.go index 8fdcc157c..72e81174b 100644 --- a/keyserver/storage/cosmosdb/key_changes_table.go +++ b/keyserver/storage/cosmosdb/key_changes_table.go @@ -35,20 +35,20 @@ import ( // ); // ` -type KeyChangeCosmos struct { +type keyChangeCosmos struct { Partition int32 `json:"partition"` Offset int64 `json:"_offset"` //offset is reserved UserID string `json:"user_id"` } -type KeyChangeUserMaxCosmosData struct { +type keyChangeUserMaxCosmosData struct { UserID string `json:"user_id"` MaxOffset int64 `json:"max_offset"` } -type KeyChangeCosmosData struct { +type keyChangeCosmosData struct { cosmosdbapi.CosmosDocument - KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"` + KeyChange keyChangeCosmos `json:"mx_keyserver_key_change"` } // Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. @@ -78,8 +78,16 @@ type keyChangesStatements struct { tableName string } -func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, docId string) (*KeyChangeCosmosData, error) { - response := KeyChangeCosmosData{} +func (s *keyChangesStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *keyChangesStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, docId string) (*keyChangeCosmosData, error) { + response := keyChangeCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -95,27 +103,6 @@ func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, d return &response, err } -func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) { - var response []KeyChangeUserMaxCosmosData - - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(qry, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - // When there are no Rows we seem to get the generic Bad Req JSON error - if err != nil { - // return nil, err - } - - return response, nil -} - func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) { s := &keyChangesStatements{ db: db, @@ -132,25 +119,23 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition in // " ON CONFLICT (partition, offset)" + // " DO UPDATE SET user_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // UNIQUE (partition, offset) docId := fmt.Sprintf("%d_%d", partition, offset) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getKeyChangeUser(s, ctx, pk, cosmosDocId) + dbData, _ := getKeyChangeUser(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.KeyChange.UserID = userID } else { - data := KeyChangeCosmos{ + data := keyChangeCosmos{ Offset: offset, Partition: partition, UserID: userID, } - dbData = &KeyChangeCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &keyChangeCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), KeyChange: data, } } @@ -175,22 +160,26 @@ func (s *keyChangesStatements) SelectKeyChanges( // "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" // rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": partition, "@x4": fromOffset, "@x5": toOffset, } - response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params) + var rows []keyChangeUserMaxCosmosData + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectKeyChangesStmt, params, &rows) if err != nil { return nil, 0, err } - for _, item := range response { + for _, item := range rows { var userID string var offset int64 userID = item.UserID diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go index e85ce6f00..0f2a52e5f 100644 --- a/keyserver/storage/cosmosdb/one_time_keys_table.go +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -42,7 +42,7 @@ import ( // ); // ` -type OneTimeKeyCosmos struct { +type oneTimeKeyCosmos struct { UserID string `json:"user_id"` DeviceID string `json:"device_id"` KeyID string `json:"key_id"` @@ -52,14 +52,14 @@ type OneTimeKeyCosmos struct { KeyJSON []byte `json:"key_json"` } -type OneTimeKeyAlgoNumberCosmosData struct { +type oneTimeKeyAlgoNumberCosmosData struct { Algorithm string `json:"algorithm"` Number int `json:"number"` } -type OneTimeKeyCosmosData struct { +type oneTimeKeyCosmosData struct { cosmosdbapi.CosmosDocument - OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"` + OneTimeKey oneTimeKeyCosmos `json:"mx_keyserver_one_time_key"` } // const upsertKeysSQL = "" + @@ -102,8 +102,16 @@ type oneTimeKeysStatements struct { tableName string } -func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*OneTimeKeyCosmosData, error) { - response := OneTimeKeyCosmosData{} +func (s *oneTimeKeysStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *oneTimeKeysStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) { + response := oneTimeKeyCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -119,53 +127,7 @@ func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, doc return &response, err } -func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) { - var response []OneTimeKeyCosmosData - - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - - return response, nil -} - -func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoNumberCosmosData, error) { - var response []OneTimeKeyAlgoNumberCosmosData - - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - // var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(qry, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - // When there are no Rows we seem to get the generic Bad Req JSON error - if err != nil { - // return nil, err - } - - return response, nil -} - -func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error { +func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData oneTimeKeyCosmosData) error { // "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + // " VALUES ($1, $2, $3, $4, $5, $6)" + // " ON CONFLICT (user_id, device_id, key_id, algorithm)" + @@ -191,7 +153,7 @@ func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData return nil } -func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error { +func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData oneTimeKeyCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -221,14 +183,19 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d // "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, } - response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params) + var rows []oneTimeKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) + if err != nil { return nil, err } @@ -239,7 +206,7 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d } result := make(map[string]json.RawMessage) - for _, item := range response { + for _, item := range rows { var keyID string var algorithm string keyID = item.OneTimeKey.KeyID @@ -260,21 +227,25 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de KeyCount: make(map[string]int), } // rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": counts.UserID, "@x3": counts.DeviceID, } // "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" - response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params) + var rows []oneTimeKeyAlgoNumberCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { var algorithm string var count int algorithm = item.Algorithm @@ -293,9 +264,6 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( KeyCount: make(map[string]int), } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { // "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + @@ -307,9 +275,9 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( // UNIQUE (user_id, device_id, key_id, algorithm) docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := OneTimeKeyCosmos{ + data := oneTimeKeyCosmos{ Algorithm: algo, DeviceID: keys.DeviceID, KeyID: keyID, @@ -317,8 +285,8 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( UserID: keys.UserID, } - dbData := &OneTimeKeyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &oneTimeKeyCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), OneTimeKey: data, } @@ -330,19 +298,24 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( } // rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": keys.UserID, "@x3": keys.DeviceID, } // "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" - response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params) + var rows []oneTimeKeyAlgoNumberCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeysCountStmt, params, &rows) if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { var algorithm string var count int algorithm = item.Algorithm @@ -361,24 +334,29 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( // "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, "@x4": algorithm, } - response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params) + var rows []oneTimeKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows) + if err != nil { if err == cosmosdbutil.ErrNoRows { return nil, nil } return nil, err } - keyID = response[0].OneTimeKey.KeyID - keyJSONBytes := response[0].OneTimeKey.KeyJSON - err = deleteOneTimeKeyCore(s, ctx, response[0]) + keyID = rows[0].OneTimeKey.KeyID + keyJSONBytes := rows[0].OneTimeKey.KeyJSON + err = deleteOneTimeKeyCore(s, ctx, rows[0]) if err != nil { return nil, err } diff --git a/keyserver/storage/cosmosdb/stale_device_lists.go b/keyserver/storage/cosmosdb/stale_device_lists.go index 6d5770ab9..60cfa46c3 100644 --- a/keyserver/storage/cosmosdb/stale_device_lists.go +++ b/keyserver/storage/cosmosdb/stale_device_lists.go @@ -35,16 +35,16 @@ import ( // CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); // ` -type StaleDeviceListCosmos struct { +type staleDeviceListCosmos struct { UserID string `json:"user_id"` Domain string `json:"domain"` IsStale bool `json:"is_stale"` AddedSecs int64 `json:"ts_added_secs"` } -type StaleDeviceListCosmosData struct { +type staleDeviceListCosmosData struct { cosmosdbapi.CosmosDocument - StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"` + StaleDeviceList staleDeviceListCosmos `json:"mx_keyserver_stale_device_list"` } // const upsertStaleDeviceListSQL = "" + @@ -72,8 +72,16 @@ type staleDeviceListsStatements struct { tableName string } -func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk string, docId string) (*StaleDeviceListCosmosData, error) { - response := StaleDeviceListCosmosData{} +func (s *staleDeviceListsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *staleDeviceListsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk string, docId string) (*staleDeviceListCosmosData, error) { + response := staleDeviceListCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -89,26 +97,6 @@ func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk s return &response, err } -func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) { - var response []StaleDeviceListCosmosData - - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(qry, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - - return response, nil -} - func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) { s := &staleDeviceListsStatements{ db: db, @@ -131,26 +119,24 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, return err } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // user_id TEXT PRIMARY KEY NOT NULL, docId := userID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getStaleDeviceList(s, ctx, pk, cosmosDocId) + dbData, _ := getStaleDeviceList(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.StaleDeviceList.IsStale = isStale dbData.StaleDeviceList.AddedSecs = time.Now().Unix() } else { - data := StaleDeviceListCosmos{ + data := staleDeviceListCosmos{ Domain: string(domain), IsStale: isStale, UserID: userID, } - dbData = &StaleDeviceListCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &staleDeviceListCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), StaleDeviceList: data, } } @@ -165,17 +151,22 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) if len(domains) == 0 { // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" // rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": true, } - rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params) + + var rows []staleDeviceListCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectStaleDeviceListsStmt, params, &rows) if err != nil { return nil, err @@ -188,12 +179,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte // "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" // rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": true, "@x3": string(domain), } - rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params) + var rows []staleDeviceListCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectStaleDeviceListsWithDomainsStmt, params, &rows) if err != nil { return nil, err @@ -207,7 +203,7 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } -func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) { +func rowsToUserIDs(ctx context.Context, rows []staleDeviceListCosmosData) (result []string, err error) { for _, item := range rows { var userID string userID = item.StaleDeviceList.UserID diff --git a/mediaapi/storage/cosmosdb/media_repository_table.go b/mediaapi/storage/cosmosdb/media_repository_table.go index e8939dcc2..ff0d010bf 100644 --- a/mediaapi/storage/cosmosdb/media_repository_table.go +++ b/mediaapi/storage/cosmosdb/media_repository_table.go @@ -54,7 +54,7 @@ import ( // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin); // ` -type MediaRepositoryCosmos struct { +type mediaRepositoryCosmos struct { MediaID string `json:"media_id"` MediaOrigin string `json:"media_origin"` ContentType string `json:"content_type"` @@ -65,9 +65,9 @@ type MediaRepositoryCosmos struct { UserID string `json:"user_id"` } -type MediaRepositoryCosmosData struct { +type mediaRepositoryCosmosData struct { cosmosdbapi.CosmosDocument - MediaRepository MediaRepositoryCosmos `json:"mx_mediaapi_media_repository"` + MediaRepository mediaRepositoryCosmos `json:"mx_mediaapi_media_repository"` } // const insertMediaSQL = ` @@ -94,29 +94,16 @@ type mediaStatements struct { tableName string } -func queryMediaRepository(s *mediaStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MediaRepositoryCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []MediaRepositoryCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *mediaStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getMediaRepository(s *mediaStatements, ctx context.Context, pk string, docId string) (*MediaRepositoryCosmosData, error) { - response := MediaRepositoryCosmosData{} +func (s *mediaStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getMediaRepository(s *mediaStatements, ctx context.Context, pk string, docId string) (*mediaRepositoryCosmosData, error) { + response := mediaRepositoryCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -150,13 +137,11 @@ func (s *mediaStatements) insertMedia( // INSERT INTO mediaapi_media_repository (media_id, media_origin, content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id) // VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin); docId := fmt.Sprintf("%s_%s", mediaMetadata.MediaID, mediaMetadata.Origin) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := MediaRepositoryCosmos{ + data := mediaRepositoryCosmos{ MediaID: string(mediaMetadata.MediaID), MediaOrigin: string(mediaMetadata.Origin), ContentType: string(mediaMetadata.ContentType), @@ -167,8 +152,8 @@ func (s *mediaStatements) insertMedia( UserID: string(mediaMetadata.UserID), } - dbData := &MediaRepositoryCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &mediaRepositoryCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), MediaRepository: data, } @@ -207,15 +192,13 @@ func (s *mediaStatements) selectMedia( // SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin); docId := fmt.Sprintf("%s_%s", mediaMetadata.MediaID, mediaMetadata.Origin) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // err := s.selectMediaStmt.QueryRowContext( // ctx, mediaMetadata.MediaID, mediaMetadata.Origin, - row, err := getMediaRepository(s, ctx, pk, cosmosDocId) + row, err := getMediaRepository(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return nil, err @@ -245,9 +228,8 @@ func (s *mediaStatements) selectMediaByHash( // SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": mediaHash, "@x3": mediaOrigin, } @@ -255,7 +237,12 @@ func (s *mediaStatements) selectMediaByHash( // err := s.selectMediaStmt.QueryRowContext( // ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, // ).Scan( - rows, err := queryMediaRepository(s, ctx, s.selectMediaByHashStmt, params) + var rows []mediaRepositoryCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMediaByHashStmt, params, &rows) if err != nil { return nil, err diff --git a/mediaapi/storage/cosmosdb/thumbnail_table.go b/mediaapi/storage/cosmosdb/thumbnail_table.go index a62a9b34b..058d1be70 100644 --- a/mediaapi/storage/cosmosdb/thumbnail_table.go +++ b/mediaapi/storage/cosmosdb/thumbnail_table.go @@ -43,7 +43,7 @@ import ( // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method); // ` -type ThumbnailCosmos struct { +type thumbnailCosmos struct { MediaID string `json:"media_id"` MediaOrigin string `json:"media_origin"` ContentType string `json:"content_type"` @@ -54,9 +54,9 @@ type ThumbnailCosmos struct { ResizeMethod string `json:"resize_method"` } -type ThumbnailCosmosData struct { +type thumbnailCosmosData struct { cosmosdbapi.CosmosDocument - Thumbnail ThumbnailCosmos `json:"mx_mediaapi_thumbnail"` + Thumbnail thumbnailCosmos `json:"mx_mediaapi_thumbnail"` } // const insertThumbnailSQL = ` @@ -85,29 +85,16 @@ type thumbnailStatements struct { tableName string } -func queryThumbnail(s *thumbnailStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThumbnailCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ThumbnailCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *thumbnailStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*ThumbnailCosmosData, error) { - response := ThumbnailCosmosData{} +func (s *thumbnailStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) { + response := thumbnailCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -142,7 +129,6 @@ func (s *thumbnailStatements) insertThumbnail( // return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { // stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method); docId := fmt.Sprintf("%s_%s_%d_%d_%s", thumbnailMetadata.MediaMetadata.MediaID, @@ -151,8 +137,7 @@ func (s *thumbnailStatements) insertThumbnail( thumbnailMetadata.ThumbnailSize.Height, thumbnailMetadata.ThumbnailSize.ResizeMethod, ) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err := stmt.ExecContext( // ctx, @@ -166,7 +151,7 @@ func (s *thumbnailStatements) insertThumbnail( // thumbnailMetadata.ThumbnailSize.ResizeMethod, // ) - data := ThumbnailCosmos{ + data := thumbnailCosmos{ MediaID: string(thumbnailMetadata.MediaMetadata.MediaID), MediaOrigin: string(thumbnailMetadata.MediaMetadata.Origin), ContentType: string(thumbnailMetadata.MediaMetadata.ContentType), @@ -177,8 +162,8 @@ func (s *thumbnailStatements) insertThumbnail( ResizeMethod: string(thumbnailMetadata.ThumbnailSize.ResizeMethod), } - dbData := &ThumbnailCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &thumbnailCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Thumbnail: data, } @@ -213,7 +198,6 @@ func (s *thumbnailStatements) selectThumbnail( // SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 AND width = $3 AND height = $4 AND resize_method = $5 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method); docId := fmt.Sprintf("%s_%s_%d_%d_%s", mediaID, @@ -222,11 +206,10 @@ func (s *thumbnailStatements) selectThumbnail( height, resizeMethod, ) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) - row, err := getThumbnail(s, ctx, pk, cosmosDocId) + row, err := getThumbnail(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return nil, err @@ -253,9 +236,8 @@ func (s *thumbnailStatements) selectThumbnails( // SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": mediaID, "@x3": mediaOrigin, } @@ -263,7 +245,12 @@ func (s *thumbnailStatements) selectThumbnails( // rows, err := s.selectThumbnailsStmt.QueryContext( // ctx, mediaID, mediaOrigin, // ) - rows, err := queryThumbnail(s, ctx, s.selectThumbnailsStmt, params) + var rows []thumbnailCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectThumbnailsStmt, params, &rows) if err != nil { return nil, err diff --git a/roomserver/storage/cosmosdb/event_json_table.go b/roomserver/storage/cosmosdb/event_json_table.go index 517dbde6c..48ff4fed7 100644 --- a/roomserver/storage/cosmosdb/event_json_table.go +++ b/roomserver/storage/cosmosdb/event_json_table.go @@ -33,14 +33,14 @@ import ( // ); // ` -type EventJSONCosmos struct { +type eventJSONCosmos struct { EventNID int64 `json:"event_nid"` EventJSON []byte `json:"event_json"` } -type EventJSONCosmosData struct { +type eventJSONCosmosData struct { cosmosdbapi.CosmosDocument - EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"` + EventJSON eventJSONCosmos `json:"mx_roomserver_event_json"` } // const insertEventJSONSQL = ` @@ -65,8 +65,16 @@ type eventJSONStatements struct { tableName string } -func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId string) (*EventJSONCosmosData, error) { - response := EventJSONCosmosData{} +func (s *eventJSONStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *eventJSONStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId string) (*eventJSONCosmosData, error) { + response := eventJSONCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -82,27 +90,6 @@ func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId return &response, err } -func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventJSONCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) { s := &eventJSONStatements{ db: db, @@ -126,24 +113,21 @@ func (s *eventJSONStatements) InsertEventJSON( // _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) // INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - docId := fmt.Sprintf("%d", eventNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getEventJSON(s, ctx, pk, cosmosDocId) + dbData, _ := getEventJSON(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.EventJSON.EventJSON = eventJSON } else { - data := EventJSONCosmos{ + data := eventJSONCosmos{ EventNID: int64(eventNID), EventJSON: eventJSON, } - dbData = &EventJSONCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &eventJSONCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), EventJSON: data, } } @@ -165,13 +149,17 @@ func (s *eventJSONStatements) BulkSelectEventJSON( // WHERE event_nid IN ($1) // ORDER BY event_nid ASC - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } - response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params) + var rows []eventJSONCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventJSONStmt, params, &rows) if err != nil { return nil, err @@ -183,7 +171,7 @@ func (s *eventJSONStatements) BulkSelectEventJSON( // We might get fewer results than NIDs so we adjust the length of the slice before returning it. results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.EventNID = types.EventNID(item.EventJSON.EventNID) result.EventJSON = item.EventJSON.EventJSON diff --git a/roomserver/storage/cosmosdb/event_state_keys_table.go b/roomserver/storage/cosmosdb/event_state_keys_table.go index bdb363e6a..cd63cb276 100644 --- a/roomserver/storage/cosmosdb/event_state_keys_table.go +++ b/roomserver/storage/cosmosdb/event_state_keys_table.go @@ -37,14 +37,14 @@ import ( // ON CONFLICT DO NOTHING; // ` -type EventStateKeysCosmos struct { +type eventStateKeysCosmos struct { EventStateKeyNID int64 `json:"event_state_key_nid"` EventStateKey string `json:"event_state_key"` } -type EventStateKeysCosmosData struct { +type eventStateKeysCosmosData struct { cosmosdbapi.CosmosDocument - EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"` + EventStateKeys eventStateKeysCosmos `json:"mx_roomserver_event_state_keys"` } // Same as insertEventTypeNIDSQL @@ -84,29 +84,16 @@ type eventStateKeyStatements struct { tableName string } -func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventStateKeysCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *eventStateKeyStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) { - response := EventStateKeysCosmosData{} +func (s *eventStateKeyStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*eventStateKeysCosmosData, error) { + response := eventStateKeysCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -144,27 +131,25 @@ func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) { // VALUES (1, '') // ON CONFLICT DO NOTHING; - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // event_state_key TEXT NOT NULL UNIQUE docId := "" - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := EventStateKeysCosmos{ + data := eventStateKeysCosmos{ EventStateKey: "", EventStateKeyNID: 1, } // event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, - dbData := EventStateKeysCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := eventStateKeysCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), EventStateKeys: data, } insertEventStateKeyCore(s, ctx, dbData) } -func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error { +func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData eventStateKeysCosmosData) error { err := cosmosdbapi.UpsertDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, @@ -189,15 +174,13 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( return 0, cosmosdbutil.ErrNoRows } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // event_state_key TEXT NOT NULL UNIQUE docId := eventStateKey - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId) + existing, _ := getEventStateKeys(s, ctx, s.getPartitionKey(), cosmosDocId) - var dbData EventStateKeysCosmosData + var dbData eventStateKeysCosmosData if existing == nil { //Not exists, we need to create a new one with a SEQ eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx) @@ -205,14 +188,14 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( return -1, seqErr } - data := EventStateKeysCosmos{ + data := eventStateKeysCosmos{ EventStateKey: eventStateKey, EventStateKeyNID: eventStateKeyNIDSeq, } // event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, - dbData = EventStateKeysCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = eventStateKeysCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), EventStateKeys: data, } } else { @@ -232,23 +215,27 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( // SELECT event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key = $1 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventStateKey, } - response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params) + var rows []eventStateKeysCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventStateKeyNIDStmt, params, &rows) if err != nil { return 0, err } //See storage.assignStateKeyNID() - if len(response) == 0 { + if len(rows) == 0 { return 0, cosmosdbutil.ErrNoRows } - return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err + return types.EventStateKeyNID(rows[0].EventStateKeys.EventStateKeyNID), err } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( @@ -262,20 +249,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key IN ($1) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventStateKeys, } - response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params) + var rows []eventStateKeysCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventStateKeyNIDStmt, params, &rows) if err != nil { return nil, err } result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) - for _, item := range response { + for _, item := range rows { result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID) } return result, nil @@ -288,19 +279,23 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( // SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys // WHERE event_state_key_nid IN ($1) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventStateKeyNIDs, } - response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params) + var rows []eventStateKeysCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventStateKeyStmt, params, &rows) if err != nil { return nil, err } result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) - for _, item := range response { + for _, item := range rows { result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey } return result, nil diff --git a/roomserver/storage/cosmosdb/event_types_table.go b/roomserver/storage/cosmosdb/event_types_table.go index fffe170fa..fc6641539 100644 --- a/roomserver/storage/cosmosdb/event_types_table.go +++ b/roomserver/storage/cosmosdb/event_types_table.go @@ -42,16 +42,16 @@ import ( // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; // ` -type EventTypeCosmosData struct { - cosmosdbapi.CosmosDocument - EventType EventTypeCosmos `json:"mx_roomserver_event_type"` -} - -type EventTypeCosmos struct { +type eventTypeCosmos struct { EventTypeNID int64 `json:"event_type_nid"` EventType string `json:"event_type"` } +type eventTypeCosmosData struct { + cosmosdbapi.CosmosDocument + EventType eventTypeCosmos `json:"mx_roomserver_event_type"` +} + // Assign a new numeric event type ID. // The usual case is that the event type is not in the database. // In that case the ID will be assigned using the next value from the sequence. @@ -96,6 +96,14 @@ type eventTypeStatements struct { tableName string } +func (s *eventTypeStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *eventTypeStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, @@ -112,27 +120,6 @@ func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) { return s, nil } -func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventTypeCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { @@ -142,7 +129,7 @@ func (s *eventTypeStatements) InsertEventTypeNID( return -1, seqErr } - data := EventTypeCosmos{ + data := eventTypeCosmos{ EventType: eventType, EventTypeNID: eventTypeNIDSeq, } @@ -156,17 +143,15 @@ func (s *eventTypeStatements) InsertEventTypeNID( return types.EventTypeNID(dbData.EventTypeNID), err } -func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) { +func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType eventTypeCosmos) (*eventTypeCosmos, error) { // INSERT INTO roomserver_event_types (event_type) VALUES ($1) // ON CONFLICT DO NOTHING; - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) //Unique on eventType - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, eventType.EventType) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), eventType.EventType) - var dbData = EventTypeCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = eventTypeCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), EventType: eventType, } @@ -200,53 +185,53 @@ func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error { // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; // (1, 'm.room.create'), - _, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"}) + _, err := insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"}) if err != nil { return err } // (2, 'm.room.power_levels'), - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"}) if err != nil { return err } // (3, 'm.room.join_rules'), - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"}) if err != nil { return err } // (4, 'm.room.third_party_invite'), - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"}) if err != nil { return err } // (5, 'm.room.member'), - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"}) if err != nil { return err } // (6, 'm.room.redaction'), - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"}) if err != nil { return err } // (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; - _, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"}) + _, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"}) if err != nil { return err } return nil } -func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) { - var response EventTypeCosmosData - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, eventType) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) +func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*eventTypeCosmos, error) { + var response eventTypeCosmosData + + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), eventType) + err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, ctx, - pk, + s.getPartitionKey(), cosmosDocId, &response) @@ -281,20 +266,24 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( // SELECT event_type, event_type_nid FROM roomserver_event_types // WHERE event_type IN ($1) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventTypes, } - response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params) + var rows []eventTypeCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventTypeNIDStmt, params, &rows) if err != nil { return nil, err } result := make(map[string]types.EventTypeNID, len(eventTypes)) - for _, item := range response { + for _, item := range rows { var eventType string var eventTypeNID int64 eventType = item.EventType.EventType diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index b8d5a6049..4c199b720 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -46,7 +46,7 @@ import ( // ); // ` -type EventCosmos struct { +type eventCosmos struct { EventNID int64 `json:"event_nid"` RoomNID int64 `json:"room_nid"` EventTypeNID int64 `json:"event_type_nid"` @@ -60,13 +60,13 @@ type EventCosmos struct { IsRejected bool `json:"is_rejected"` } -type EventCosmosMaxDepth struct { +type eventCosmosMaxDepth struct { Max int64 `json:"maxdepth"` } -type EventCosmosData struct { +type eventCosmosData struct { cosmosdbapi.CosmosDocument - Event EventCosmos `json:"mx_roomserver_event"` + Event eventCosmos `json:"mx_roomserver_event"` } // const insertEventSQL = ` @@ -173,6 +173,14 @@ type eventStatements struct { tableName string } +func (s *eventStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *eventStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func NewCosmosDBEventsTable(db *Database) (tables.Events, error) { s := &eventStatements{ db: db, @@ -207,29 +215,8 @@ func mapFromEventNIDArray(eventNIDs []types.EventNID) []int64 { return result } -func queryEvent(s *eventStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []EventCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) (*EventCosmosData, error) { - response := EventCosmosData{} +func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) (*eventCosmosData, error) { + response := eventCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -245,7 +232,7 @@ func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) return &response, err } -func setEvent(s *eventStatements, ctx context.Context, event EventCosmosData) (*EventCosmosData, error) { +func setEvent(s *eventStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -288,7 +275,7 @@ func isReferenceSha256Same( } func isEventSame( - event EventCosmos, + event eventCosmos, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -343,12 +330,10 @@ func (s *eventStatements) InsertEvent( // event_nid INTEGER PRIMARY KEY AUTOINCREMENT, // event_id TEXT NOT NULL UNIQUE, - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := eventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, errGet := getEvent(s, ctx, pk, cosmosDocId) + dbData, errGet := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId) // ON CONFLICT DO NOTHING; // event_nid INTEGER PRIMARY KEY AUTOINCREMENT, @@ -358,7 +343,7 @@ func (s *eventStatements) InsertEvent( if seqErr != nil { return 0, 0, seqErr } - data := EventCosmos{ + data := eventCosmos{ AuthEventNIDs: mapFromEventNIDArray(authEventNIDs), Depth: depth, EventId: eventID, @@ -370,8 +355,8 @@ func (s *eventStatements) InsertEvent( RoomNID: int64(roomNID), } - dbData = &EventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &eventCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Event: data, } } else { @@ -424,11 +409,9 @@ func (s *eventStatements) SelectEvent( ) (types.EventNID, types.StateSnapshotNID, error) { // "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) docId := eventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - var response, err = getEvent(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, err = getEvent(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return 0, 0, err } @@ -451,13 +434,17 @@ func (s *eventStatements) BulkSelectStateEventByID( // " WHERE event_id IN ($1)" + // " ORDER BY event_type_nid, event_state_key_nid ASC" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectStateEventByIDStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectStateEventByIDStmt, params, &rows) if err != nil { return nil, err @@ -467,9 +454,9 @@ func (s *eventStatements) BulkSelectStateEventByID( // because of the unique constraint on event IDs. // So we can allocate an array of the correct size now. // We might get fewer results than IDs so we adjust the length of the slice before returning it. - results := make([]types.StateEntry, len(response)) + results := make([]types.StateEntry, len(rows)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) @@ -502,9 +489,8 @@ func (s *eventStatements) BulkSelectStateEventByNID( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() // params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } // selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) @@ -536,7 +522,13 @@ func (s *eventStatements) BulkSelectStateEventByNID( // return nil, fmt.Errorf("s.db.Prepare: %w", err) // } // rows, err := selectStmt.QueryContext(ctx, params...) - rows, err := queryEvent(s, ctx, selectOrig, params) + + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectOrig, params, &rows) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -578,13 +570,17 @@ func (s *eventStatements) BulkSelectStateAtEventByID( // "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + // " WHERE event_id IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventByIDStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectStateAtEventByIDStmt, params, &rows) if err != nil { return nil, err @@ -592,7 +588,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( results := make([]types.StateAtEvent, len(eventIDs)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) @@ -620,19 +616,23 @@ func (s *eventStatements) UpdateEventState( // "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNID, } - response, err := queryEvent(s, ctx, s.updateEventStateStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.updateEventStateStmt, params, &rows) if err != nil { return err } - item := response[0] + item := rows[0] item.Event.StateSnapshotNID = int64(stateNID) var _, exReplace = setEvent(s, ctx, item) @@ -648,19 +648,23 @@ func (s *eventStatements) SelectEventSentToOutput( // "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNID, } - response, err := queryEvent(s, ctx, s.selectEventSentToOutputStmt, params) + var rows []eventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventSentToOutputStmt, params, &rows) if err != nil { return false, err } - item := response[0] + item := rows[0] sentToOutput = item.Event.SentToOutput return } @@ -669,19 +673,23 @@ func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql. // "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNID, } - response, err := queryEvent(s, ctx, s.updateEventSentToOutputStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.updateEventSentToOutputStmt, params, &rows) if err != nil { return err } - item := response[0] + item := rows[0] item.Event.SentToOutput = true var _, exReplace = setEvent(s, ctx, item) @@ -697,19 +705,23 @@ func (s *eventStatements) SelectEventID( // "SELECT event_id FROM roomserver_events WHERE event_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNID, } - response, err := queryEvent(s, ctx, s.selectEventIDStmt, params) + var rows []eventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventIDStmt, params, &rows) if err != nil { return "", err } - item := response[0] + item := rows[0] eventNID = types.EventNID(item.Event.EventNID) return } @@ -724,21 +736,25 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( // "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + // " FROM roomserver_events WHERE event_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventAndReferenceStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectStateAtEventAndReferenceStmt, params, &rows) if err != nil { return nil, err } - results := make([]types.StateAtEventAndReference, len(response)) + results := make([]types.StateAtEventAndReference, len(rows)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) @@ -762,13 +778,17 @@ func (s *eventStatements) BulkSelectEventReference( } // "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectEventReferenceStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventReferenceStmt, params, &rows) if err != nil { return nil, err @@ -776,7 +796,7 @@ func (s *eventStatements) BulkSelectEventReference( results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.EventID = item.Event.EventId result.EventSHA256 = item.Event.ReferenceSha256 @@ -797,20 +817,24 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectEventIDStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventIDStmt, params, &rows) if err != nil { return nil, err } i := 0 - for _, item := range response { + for _, item := range rows { eventNID := item.Event.EventNID eventID := item.Event.EventId results[types.EventNID(eventNID)] = eventID @@ -830,20 +854,24 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str } // "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventIDs, } - response, err := queryEvent(s, ctx, s.bulkSelectEventNIDStmt, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectEventNIDStmt, params, &rows) if err != nil { return nil, err } results := make(map[string]types.EventNID, len(eventIDs)) - for _, item := range response { + for _, item := range rows { eventID := item.Event.EventId eventNID := item.Event.EventNID results[eventID] = types.EventNID(eventNID) @@ -857,28 +885,23 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } // "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var response []EventCosmosMaxDepth params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": eventNIDs, } - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(selectMaxEventDepthSQL, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rows []eventCosmosMaxDepth + err := cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + selectMaxEventDepthSQL, params, &rows) if err != nil { return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } - return response[0].Max, nil + return rows[0].Max, nil } func (s *eventStatements) SelectRoomNIDsForEventNIDs( @@ -890,20 +913,24 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( // "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventNIDs, } - response, err := queryEvent(s, ctx, selectRoomNIDsForEventNIDsSQL, params) + var rows []eventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectRoomNIDsForEventNIDsSQL, params, &rows) if err != nil { return nil, err } result := make(map[types.EventNID]types.RoomNID) - for _, item := range response { + for _, item := range rows { roomNID := types.RoomNID(item.Event.RoomNID) eventNID := types.EventNID(item.Event.EventNID) result[eventNID] = roomNID diff --git a/roomserver/storage/cosmosdb/invite_table.go b/roomserver/storage/cosmosdb/invite_table.go index d173768c6..68f5b6c60 100644 --- a/roomserver/storage/cosmosdb/invite_table.go +++ b/roomserver/storage/cosmosdb/invite_table.go @@ -40,7 +40,7 @@ import ( // WHERE NOT retired; // ` -type InviteCosmos struct { +type inviteCosmos struct { InviteEventID string `json:"invite_event_id"` RoomNID int64 `json:"room_nid"` TargetNID int64 `json:"target_nid"` @@ -49,9 +49,9 @@ type InviteCosmos struct { InviteEventJSON []byte `json:"invite_event_json"` } -type InviteCosmosData struct { +type inviteCosmosData struct { cosmosdbapi.CosmosDocument - Invite InviteCosmos `json:"mx_roomserver_invite"` + Invite inviteCosmos `json:"mx_roomserver_invite"` } // const insertInviteEventSQL = "" + @@ -93,29 +93,16 @@ type inviteStatements struct { tableName string } -func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []InviteCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *inviteStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) { - response := InviteCosmosData{} +func (s *inviteStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) { + response := inviteCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -131,7 +118,7 @@ func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string return &response, err } -func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) { +func setInvite(s *inviteStatements, ctx context.Context, invite inviteCosmosData) (*inviteCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -168,9 +155,11 @@ func (s *inviteStatements) InsertInviteEvent( // " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + // " ON CONFLICT DO NOTHING" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // invite_event_id TEXT PRIMARY KEY, + docId := inviteEventID + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := InviteCosmos{ + data := inviteCosmos{ InviteEventID: inviteEventID, InviteEventJSON: inviteEventJSON, Retired: false, @@ -179,13 +168,8 @@ func (s *inviteStatements) InsertInviteEvent( TargetNID: int64(targetUserNID), } - // invite_event_id TEXT PRIMARY KEY, - docId := inviteEventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - var dbData = InviteCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = inviteCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Invite: data, } @@ -216,20 +200,24 @@ func (s *inviteStatements) UpdateInviteRetired( // " AND NOT retired" // gather all the event IDs we will retire - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": targetUserNID, "@x3": roomNID, } - response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params) + var rows []inviteCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectInvitesAboutToRetireStmt, params, &rows) if err != nil { return } - for _, item := range response { + for _, item := range rows { eventIDs = append(eventIDs, item.Invite.InviteEventID) // UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired @@ -249,14 +237,18 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom( // SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, "@x3": targetUserNID, } - response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params) + var rows []inviteCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectInviteActiveForUserInRoomStmt, params, &rows) if err != nil { return nil, nil, err @@ -264,7 +256,7 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom( var result []types.EventStateKeyNID var eventIDs []string - for _, item := range response { + for _, item := range rows { var eventID = item.Invite.InviteEventID var senderUserNID = item.Invite.SenderNID result = append(result, types.EventStateKeyNID(senderUserNID)) diff --git a/roomserver/storage/cosmosdb/membership_table.go b/roomserver/storage/cosmosdb/membership_table.go index e5637a930..452c1fbcc 100644 --- a/roomserver/storage/cosmosdb/membership_table.go +++ b/roomserver/storage/cosmosdb/membership_table.go @@ -41,7 +41,7 @@ import ( // ); // ` -type MembershipCosmos struct { +type membershipCosmos struct { RoomNID int64 `json:"room_nid"` TargetNID int64 `json:"target_nid"` SenderNID int64 `json:"sender_nid"` @@ -51,9 +51,9 @@ type MembershipCosmos struct { Forgotten bool `json:"forgotten"` } -type MembershipCosmosData struct { +type membershipCosmosData struct { cosmosdbapi.CosmosDocument - Membership MembershipCosmos `json:"mx_roomserver_membership"` + Membership membershipCosmos `json:"mx_roomserver_membership"` } type MembershipJoinedCountCosmosData struct { @@ -199,29 +199,24 @@ type membershipStatements struct { tableName string } -func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []MembershipCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *membershipStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) { - response := MembershipCosmosData{} +func (s *membershipStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func (s *membershipStatements) getCollectionEventStateKeys() string { + return "roomserver_event_state_keys" +} + +func (s *membershipStatements) getPartitionKeyEventStateKeys() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionEventStateKeys()) +} + +func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*membershipCosmosData, error) { + response := membershipCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -237,7 +232,7 @@ func getMembership(s *membershipStatements, ctx context.Context, pk string, docI return &response, err } -func setMembership(s *membershipStatements, ctx context.Context, membership MembershipCosmosData) (*MembershipCosmosData, error) { +func setMembership(s *membershipStatements, ctx context.Context, membership membershipCosmosData) (*membershipCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(membership.Pk, membership.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -284,15 +279,12 @@ func (s *membershipStatements) InsertMembership( // " VALUES ($1, $2, $3)" + // " ON CONFLICT DO NOTHING" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // UNIQUE (room_nid, target_nid) docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // " ON CONFLICT DO NOTHING" - exists, _ := getMembership(s, ctx, pk, cosmosDocId) + exists, _ := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if exists != nil { exists.Membership.RoomNID = int64(roomNID) exists.Membership.TargetNID = int64(targetUserNID) @@ -302,7 +294,7 @@ func (s *membershipStatements) InsertMembership( return errSet } - data := MembershipCosmos{ + data := membershipCosmos{ EventNID: 0, Forgotten: false, MembershipNID: 1, @@ -312,8 +304,8 @@ func (s *membershipStatements) InsertMembership( TargetNID: int64(targetUserNID), } - var dbData = MembershipCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = membershipCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Membership: data, } @@ -336,11 +328,10 @@ func (s *membershipStatements) SelectMembershipForUpdate( // "SELECT membership_nid FROM roomserver_membership" + // " WHERE room_nid = $1 AND target_nid = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - response, err := getMembership(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + + response, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if response != nil { membership = tables.MembershipState(response.Membership.MembershipNID) } @@ -355,11 +346,9 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget( // "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + // " WHERE room_nid = $1 AND target_nid = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - response, err := getMembership(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + response, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if response != nil { eventNID = types.EventNID(response.Membership.EventNID) forgotten = response.Membership.Forgotten @@ -374,9 +363,8 @@ func (s *membershipStatements) SelectMembershipsFromRoom( ) (eventNIDs []types.EventNID, err error) { var selectStmt string - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, } if localOnly { @@ -390,12 +378,19 @@ func (s *membershipStatements) SelectMembershipsFromRoom( // " WHERE room_nid = $1 and forgotten = false" selectStmt = s.selectMembershipsFromRoomStmt } - response, err := queryMembership(s, ctx, selectStmt, params) + + var rows []membershipCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectStmt, params, &rows) + if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID)) } return @@ -407,9 +402,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ) (eventNIDs []types.EventNID, err error) { var stmt string - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, "@x3": membership, } @@ -423,12 +417,18 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( // " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" stmt = s.selectMembershipsFromRoomAndMembershipStmt } - response, err := queryMembership(s, ctx, stmt, params) + var rows []membershipCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), stmt, params, &rows) + if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID)) } return @@ -443,11 +443,9 @@ func (s *membershipStatements) UpdateMembership( // "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + // " WHERE room_nid = $5 AND target_nid = $6" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - dbData, err := getMembership(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + dbData, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err @@ -468,18 +466,23 @@ func (s *membershipStatements) SelectRoomsWithMembership( // "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": membershipState, "@x3": userID, } - response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params) + var rows []membershipCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectRoomsWithMembershipStmt, params, &rows) + if err != nil { return nil, err } var roomNIDs []types.RoomNID - for _, item := range response { + for _, item := range rows { roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID)) } return roomNIDs, nil @@ -495,30 +498,24 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, // " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + // " GROUP BY target_nid" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNIDs, } - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []MembershipJoinedCountCosmosData - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rows []MembershipJoinedCountCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + s.getPartitionKey(), selectJoinedUsersSetForRoomsSQL, params, &rows) if err != nil { return nil, err } result := make(map[types.EventStateKeyNID]int) - for _, item := range response { + for _, item := range rows { userID := types.EventStateKeyNID(item.TargetNID) count := item.RoomCount result[userID] = count @@ -532,20 +529,25 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room var nid types.RoomNID // err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) // - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": tables.MembershipStateJoin, "@x3": roomNID, } - response, err := queryMembership(s, ctx, s.selectLocalServerInRoomStmt, params) - if len(response) == 0 { + var rows []membershipCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectLocalServerInRoomStmt, params, &rows) + + if len(rows) == 0 { if err == cosmosdbutil.ErrNoRows { return false, nil } return false, err } - nid = types.RoomNID(response[0].Membership.RoomNID) + nid = types.RoomNID(rows[0].Membership.RoomNID) found := nid > 0 return found, nil @@ -564,27 +566,20 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID t "select * from c where c._cn = @x1 " + "and (endswith(c.mx_roomserver_event_state_keys.event_state_key, \":\") or c.mx_roomserver_event_state_keys.event_state_key = @x2) " - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionEventStateKeys(), "@x2": serverName, } - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var eventStateKeys []EventStateKeysCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectEventStateKeyNIDSQL, params) // - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rowsEventsStateKeys []eventStateKeysCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &eventStateKeys, - optionsQry) + s.getPartitionKeyEventStateKeys(), selectEventStateKeyNIDSQL, params, &rowsEventsStateKeys) eventStateKeyNids := []int64{} - for _, item := range eventStateKeys { + for _, item := range rowsEventsStateKeys { eventStateKeyNids = append(eventStateKeyNids, item.EventStateKeys.EventStateKeyNID) } @@ -595,19 +590,25 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID t // err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) params = map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": tables.MembershipStateJoin, "@x3": roomNID, "@x4": eventStateKeyNids, } - response, err := queryMembership(s, ctx, s.selectServerInRoomStmt, params) - if len(response) == 0 { + var rows []membershipCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectServerInRoomStmt, params, &rows) + + if len(rows) == 0 { if err == cosmosdbutil.ErrNoRows { return false, nil } return false, err } - nid = types.RoomNID(response[0].Membership.RoomNID) + nid = types.RoomNID(rows[0].Membership.RoomNID) return roomNID == nid, nil } @@ -616,33 +617,26 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type // " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + // ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": searchString, "@x4": limit, } - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var responseDistinctRoom []MembershipCosmos - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) // - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rowsDistinctRoom []membershipCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &responseDistinctRoom, - optionsQry) + s.getPartitionKey(), selectKnownUsersSQLDistinctRoom, params, &rowsDistinctRoom) if err != nil { return nil, err } rooms := []int64{} - for _, item := range responseDistinctRoom { + for _, item := range rowsDistinctRoom { rooms = append(rooms, item.RoomNID) } @@ -651,47 +645,40 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type // " WHERE room_nid IN (" + params = map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": rooms, } - var responseRooms []MembershipCosmos - query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params) - _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rows []membershipCosmos + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &responseRooms, - optionsQry) + s.getPartitionKey(), selectKnownUsersSQLRooms, params, &rows) if err != nil { return nil, err } targetNIDs := []int64{} - for _, item := range responseRooms { + for _, item := range rows { targetNIDs = append(targetNIDs, item.TargetNID) } // HACK: Joined table - var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys") params = map[string]interface{}{ - "@x1": dbCollectionNameEventStateKeys, + "@x1": s.getCollectionEventStateKeys(), "@x2": targetNIDs, } bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)" - var responseEventStateKeys []EventStateKeysCosmos - query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params) - _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rowsEventStateKeys []eventStateKeysCosmos + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &responseEventStateKeys, - optionsQry) + s.getPartitionKeyEventStateKeys(), bulkSelectEventStateKeyStmt, params, &rowsEventStateKeys) if err != nil { return nil, err @@ -700,7 +687,7 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type // SELECT DISTINCT event_state_key result := []string{} - for _, item := range responseEventStateKeys { + for _, item := range rowsEventStateKeys { userID := item.EventStateKey result = append(result, userID) } @@ -716,11 +703,9 @@ func (s *membershipStatements) UpdateForgetMembership( // "UPDATE roomserver_membership SET forgotten = $1" + // " WHERE room_nid = $2 AND target_nid = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - dbData, err := getMembership(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + dbData, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err diff --git a/roomserver/storage/cosmosdb/previous_events_table.go b/roomserver/storage/cosmosdb/previous_events_table.go index 4fb9422a9..448ade8a1 100644 --- a/roomserver/storage/cosmosdb/previous_events_table.go +++ b/roomserver/storage/cosmosdb/previous_events_table.go @@ -43,15 +43,15 @@ import ( // ); // ` -type PreviousEventCosmos struct { +type previousEventCosmos struct { PreviousEventID string `json:"previous_event_id"` PreviousReferenceSha256 []byte `json:"previous_reference_sha256"` EventNIDs string `json:"event_nids"` } -type PreviousEventCosmosData struct { +type previousEventCosmosData struct { cosmosdbapi.CosmosDocument - PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"` + PreviousEvent previousEventCosmos `json:"mx_roomserver_previous_event"` } // Insert an entry into the previous_events table. @@ -85,8 +85,16 @@ type previousEventStatements struct { tableName string } -func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) { - response := PreviousEventCosmosData{} +func (s *previousEventStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *previousEventStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) { + response := previousEventCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -125,18 +133,15 @@ func (s *previousEventStatements) InsertPreviousEvent( ) error { eventNIDAsString := fmt.Sprintf("%d", eventNID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // UNIQUE (previous_event_id, previous_reference_sha256) // TODO: Check value // docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256) docId := previousEventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // SELECT 1 FROM roomserver_previous_events // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 - existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId) + existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { if err != cosmosdbutil.ErrNoRows { @@ -144,17 +149,17 @@ func (s *previousEventStatements) InsertPreviousEvent( } } - var dbData PreviousEventCosmosData + var dbData previousEventCosmosData // Doesnt exist, create a new one if existing == nil { - data := PreviousEventCosmos{ + data := previousEventCosmos{ EventNIDs: "", PreviousEventID: previousEventID, PreviousReferenceSha256: previousEventReferenceSHA256, } - dbData = PreviousEventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = previousEventCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), PreviousEvent: data, } } else { @@ -192,18 +197,16 @@ func (s *previousEventStatements) InsertPreviousEvent( func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (previous_event_id, previous_reference_sha256) // TODO: Check value // docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256) docId := eventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, string(docId)) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), string(docId)) // SELECT 1 FROM roomserver_previous_events // WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 - dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId) + dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err } diff --git a/roomserver/storage/cosmosdb/published_table.go b/roomserver/storage/cosmosdb/published_table.go index 7a84ee696..5fb7c5e8a 100644 --- a/roomserver/storage/cosmosdb/published_table.go +++ b/roomserver/storage/cosmosdb/published_table.go @@ -34,14 +34,14 @@ import ( // ); // ` -type PublishCosmos struct { +type publishCosmos struct { RoomID string `json:"room_id"` Published bool `json:"published"` } -type PublishCosmosData struct { +type publishCosmosData struct { cosmosdbapi.CosmosDocument - Publish PublishCosmos `json:"mx_roomserver_publish"` + Publish publishCosmos `json:"mx_roomserver_publish"` } // const upsertPublishedSQL = "" + @@ -64,29 +64,16 @@ type publishedStatements struct { tableName string } -func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []PublishCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *publishedStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) { - response := PublishCosmosData{} +func (s *publishedStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*publishCosmosData, error) { + response := publishCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -125,24 +112,21 @@ func (s *publishedStatements) UpsertRoomPublished( // "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // room_id TEXT NOT NULL PRIMARY KEY, docId := roomID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - dbData, _ := getPublish(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + dbData, _ := getPublish(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.Publish.Published = published } else { - data := PublishCosmos{ + data := publishCosmos{ RoomID: roomID, Published: false, } - dbData = &PublishCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &publishCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Publish: data, } } @@ -161,13 +145,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( ) (published bool, err error) { // "SELECT published FROM roomserver_published WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // room_id TEXT NOT NULL PRIMARY KEY, docId := roomID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - response, err := getPublish(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + response, err := getPublish(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return false, err } @@ -179,19 +160,24 @@ func (s *publishedStatements) SelectAllPublishedRooms( ) ([]string, error) { // "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": published, } - response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params) + var rows []publishCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAllPublishedStmt, params, &rows) + if err != nil { return nil, err } var roomIDs []string - for _, item := range response { + for _, item := range rows { roomIDs = append(roomIDs, item.Publish.RoomID) } return roomIDs, nil diff --git a/roomserver/storage/cosmosdb/redactions_table.go b/roomserver/storage/cosmosdb/redactions_table.go index 5994129d5..c3dcad635 100644 --- a/roomserver/storage/cosmosdb/redactions_table.go +++ b/roomserver/storage/cosmosdb/redactions_table.go @@ -37,15 +37,15 @@ import ( // ); // ` -type RedactionCosmos struct { +type redactionCosmos struct { RedactionEventID string `json:"redaction_event_id"` RedactsEventID string `json:"redacts_event_id"` Validated bool `json:"validated"` } -type RedactionCosmosData struct { +type redactionCosmosData struct { cosmosdbapi.CosmosDocument - Redaction RedactionCosmos `json:"mx_roomserver_redaction"` + Redaction redactionCosmos `json:"mx_roomserver_redaction"` } // const insertRedactionSQL = "" + @@ -74,29 +74,16 @@ type redactionStatements struct { tableName string } -func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []RedactionCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *redactionStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) { - response := RedactionCosmosData{} +func (s *redactionStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*redactionCosmosData, error) { + response := redactionCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -112,7 +99,7 @@ func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId return &response, err } -func setRedaction(s *redactionStatements, ctx context.Context, redaction RedactionCosmosData) (*RedactionCosmosData, error) { +func setRedaction(s *redactionStatements, ctx context.Context, redaction redactionCosmosData) (*redactionCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(redaction.Pk, redaction.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -146,20 +133,18 @@ func (s *redactionStatements) InsertRedaction( // "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" + // " VALUES ($1, $2, $3)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // redaction_event_id TEXT PRIMARY KEY, docId := info.RedactionEventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := RedactionCosmos{ + data := redactionCosmos{ RedactionEventID: info.RedactionEventID, RedactsEventID: info.RedactsEventID, Validated: info.Validated, } - var dbData = RedactionCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = redactionCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Redaction: data, } @@ -188,13 +173,11 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + // " WHERE redaction_event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // redaction_event_id TEXT PRIMARY KEY, docId := redactionEventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getRedaction(s, ctx, pk, cosmosDocId) + response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { info = nil err = err @@ -221,27 +204,31 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( // "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" + // " WHERE redacts_event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventID, } - response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params) + var rows []redactionCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectRedactionInfoByEventBeingRedactedStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { + if len(rows) == 0 { info = nil err = nil return } // TODO: Check this is ok to return the 1st one *info = tables.RedactionInfo{ - RedactionEventID: response[0].Redaction.RedactionEventID, - RedactsEventID: response[0].Redaction.RedactsEventID, - Validated: response[0].Redaction.Validated, + RedactionEventID: rows[0].Redaction.RedactionEventID, + RedactsEventID: rows[0].Redaction.RedactsEventID, + Validated: rows[0].Redaction.Validated, } return } @@ -252,13 +239,11 @@ func (s *redactionStatements) MarkRedactionValidated( // " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // redaction_event_id TEXT PRIMARY KEY, docId := redactionEventID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getRedaction(s, ctx, pk, cosmosDocId) + response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err } diff --git a/roomserver/storage/cosmosdb/room_aliases_table.go b/roomserver/storage/cosmosdb/room_aliases_table.go index 815154f43..005ec737b 100644 --- a/roomserver/storage/cosmosdb/room_aliases_table.go +++ b/roomserver/storage/cosmosdb/room_aliases_table.go @@ -33,15 +33,15 @@ import ( // CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); // ` -type RoomAliasCosmos struct { +type roomAliasCosmos struct { Alias string `json:"alias"` RoomID string `json:"room_id"` CreatorID string `json:"creator_id"` } -type RoomAliasCosmosData struct { +type roomAliasCosmosData struct { cosmosdbapi.CosmosDocument - RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"` + RoomAlias roomAliasCosmos `json:"mx_roomserver_room_alias"` } // const insertRoomAliasSQL = ` @@ -75,29 +75,16 @@ type roomAliasesStatements struct { tableName string } -func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []RoomAliasCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *roomAliasesStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) { - response := RoomAliasCosmosData{} +func (s *roomAliasesStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*roomAliasCosmosData, error) { + response := roomAliasCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -137,21 +124,19 @@ func (s *roomAliasesStatements) InsertRoomAlias( ) error { // INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) - data := RoomAliasCosmos{ + + // alias TEXT NOT NULL PRIMARY KEY, + docId := alias + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + + data := roomAliasCosmos{ Alias: alias, CreatorID: creatorUserID, RoomID: roomID, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - - // alias TEXT NOT NULL PRIMARY KEY, - docId := alias - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - var dbData = RoomAliasCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = roomAliasCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), RoomAlias: data, } @@ -172,13 +157,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( // SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // alias TEXT NOT NULL PRIMARY KEY, docId := alias - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - response, err := getRoomAlias(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + response, err := getRoomAlias(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return "", err @@ -198,19 +180,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( // SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } - response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params) + var rows []roomAliasCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAliasesFromRoomIDStmt, params, &rows) if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { aliases = append(aliases, item.RoomAlias.Alias) } @@ -223,13 +209,10 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( // SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // alias TEXT NOT NULL PRIMARY KEY, docId := alias - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - response, err := getRoomAlias(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + response, err := getRoomAlias(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return "", err @@ -248,11 +231,9 @@ func (s *roomAliasesStatements) DeleteRoomAlias( // DELETE FROM roomserver_room_aliases WHERE alias = $1 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) docId := alias - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey()) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, s.db.cosmosConfig.DatabaseName, diff --git a/roomserver/storage/cosmosdb/rooms_table.go b/roomserver/storage/cosmosdb/rooms_table.go index 9cf749b8c..d342535f8 100644 --- a/roomserver/storage/cosmosdb/rooms_table.go +++ b/roomserver/storage/cosmosdb/rooms_table.go @@ -40,12 +40,7 @@ import ( // ); // ` -type RoomCosmosData struct { - cosmosdbapi.CosmosDocument - Room RoomCosmos `json:"mx_roomserver_room"` -} - -type RoomCosmos struct { +type roomCosmos struct { RoomNID int64 `json:"room_nid"` RoomID string `json:"room_id"` LatestEventNIDs []int64 `json:"latest_event_nids"` @@ -54,6 +49,11 @@ type RoomCosmos struct { RoomVersion string `json:"room_version"` } +type roomCosmosData struct { + cosmosdbapi.CosmosDocument + Room roomCosmos `json:"mx_roomserver_room"` +} + // Same as insertEventTypeNIDSQL // const insertRoomNIDSQL = ` // INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) @@ -113,6 +113,14 @@ type roomStatements struct { tableName string } +func (s *roomStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *roomStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) { s := &roomStatements{ db: db, @@ -139,29 +147,8 @@ func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID { return result } -func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []RoomCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) { - response := RoomCosmosData{} +func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*roomCosmosData, error) { + response := roomCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -177,7 +164,7 @@ func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (* return &response, err } -func setRoom(s *roomStatements, ctx context.Context, room RoomCosmosData) (*RoomCosmosData, error) { +func setRoom(s *roomStatements, ctx context.Context, room roomCosmosData) (*roomCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(room.Pk, room.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -193,19 +180,23 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { // "SELECT room_id FROM roomserver_rooms" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectRoomIDsStmt, params, &rows) if err != nil { return nil, err } var roomIDs []string - for _, item := range response { + for _, item := range rows { roomIDs = append(roomIDs, item.Room.RoomID) } return roomIDs, nil @@ -216,12 +207,10 @@ func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*ty // "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // room_id TEXT NOT NULL UNIQUE, docId := roomID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - room, err := getRoom(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + room, err := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { if err == cosmosdbutil.ErrNoRows { @@ -242,16 +231,13 @@ func (s *roomStatements) InsertRoomNID( roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2) // ON CONFLICT DO NOTHING; // room_id TEXT NOT NULL UNIQUE, docId := roomID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, errGet := getRoom(s, ctx, pk, cosmosDocId) + dbData, errGet := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId) if errGet == cosmosdbutil.ErrNoRows { // room_nid INTEGER PRIMARY KEY AUTOINCREMENT, @@ -260,14 +246,14 @@ func (s *roomStatements) InsertRoomNID( return 0, seqErr } - data := RoomCosmos{ + data := roomCosmos{ RoomNID: int64(roomNIDSeq), RoomID: roomID, RoomVersion: string(roomVersion), } - dbData = &RoomCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &roomCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Room: data, } } else { @@ -298,12 +284,10 @@ func (s *roomStatements) SelectRoomNID( // "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // room_id TEXT NOT NULL UNIQUE, docId := roomID - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - room, err := getRoom(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + room, err := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return 0, err @@ -321,25 +305,29 @@ func (s *roomStatements) SelectLatestEventNIDs( // "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, } - response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectLatestEventNIDsStmt, params, &rows) if err != nil { return nil, 0, err } // TODO: Check the error handling - if len(response) == 0 { + if len(rows) == 0 { return nil, 0, cosmosdbutil.ErrNoRows } //Assume 1 per RoomNID - room := response[0] + room := rows[0] return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil } @@ -349,25 +337,29 @@ func (s *roomStatements) SelectLatestEventsNIDsForUpdate( // "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, } - response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectLatestEventNIDsForUpdateStmt, params, &rows) if err != nil { return nil, 0, 0, err } // TODO: Check the error handling - if len(response) == 0 { + if len(rows) == 0 { return nil, 0, 0, cosmosdbutil.ErrNoRows } //Assume 1 per RoomNID - room := response[0] + room := rows[0] return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil } @@ -382,25 +374,29 @@ func (s *roomStatements) UpdateLatestEventNIDs( // "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNID, } - response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectLatestEventNIDsForUpdateStmt, params, &rows) if err != nil { return err } // TODO: Check the error handling - if len(response) == 0 { + if len(rows) == 0 { return cosmosdbutil.ErrNoRows } //Assume 1 per RoomNID - room := response[0] + room := rows[0] room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs) room.Room.LastEventSentNID = int64(lastEventSentNID) @@ -419,20 +415,24 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( // "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNIDs, } - response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectRoomVersionForRoomNIDStmt, params, &rows) if err != nil { return nil, err } result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) - for _, item := range response { + for _, item := range rows { result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion) } return result, nil @@ -445,20 +445,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types // "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomNIDs, } - response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), bulkSelectRoomIDsSQL, params, &rows) if err != nil { return nil, err } var roomIDs []string - for _, item := range response { + for _, item := range rows { roomIDs = append(roomIDs, item.Room.RoomID) } return roomIDs, nil @@ -471,20 +475,24 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []strin // "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomIDs, } - response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params) + var rows []roomCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), bulkSelectRoomNIDsSQL, params, &rows) if err != nil { return nil, err } var roomNIDs []types.RoomNID - for _, item := range response { + for _, item := range rows { roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID)) } return roomNIDs, nil diff --git a/roomserver/storage/cosmosdb/state_block_table.go b/roomserver/storage/cosmosdb/state_block_table.go index ac7b7f590..2b25d0c73 100644 --- a/roomserver/storage/cosmosdb/state_block_table.go +++ b/roomserver/storage/cosmosdb/state_block_table.go @@ -42,19 +42,19 @@ import ( // ); // ` -type StateBlockCosmos struct { +type stateBlockCosmos struct { StateBlockNID int64 `json:"state_block_nid"` StateBlockHash []byte `json:"state_block_hash"` EventNIDs []int64 `json:"event_nids"` } -type StateBlockCosmosMaxNID struct { +type stateBlockCosmosMaxNID struct { Max int64 `json:"maxstateblocknid"` } -type StateBlockCosmosData struct { +type stateBlockCosmosData struct { cosmosdbapi.CosmosDocument - StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"` + StateBlock stateBlockCosmos `json:"mx_roomserver_state_block"` } // Insert a new state block. If we conflict on the hash column then @@ -82,29 +82,16 @@ type stateBlockStatements struct { tableName string } -func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []StateBlockCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *stateBlockStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*StateBlockCosmosData, error) { - response := StateBlockCosmosData{} +func (s *stateBlockStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*stateBlockCosmosData, error) { + response := stateBlockCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -120,7 +107,7 @@ func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docI return &response, err } -func setStateBlock(s *stateBlockStatements, ctx context.Context, item StateBlockCosmosData) (*StateBlockCosmosData, error) { +func setStateBlock(s *stateBlockStatements, ctx context.Context, item stateBlockCosmosData) (*stateBlockCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(item.Pk, item.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -168,16 +155,12 @@ func (s *stateBlockStatements) BulkInsertStateData( // ctx, nids.Hash(), js, // ).Scan(&id) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // state_block_hash BLOB UNIQUE, docId := hex.EncodeToString(nids.Hash()) - - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) //See if it exists - existing, err := getStateBlock(s, ctx, pk, cosmosDocId) + existing, err := getStateBlock(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { if err != cosmosdbutil.ErrNoRows { return 0, err @@ -199,14 +182,14 @@ func (s *stateBlockStatements) BulkInsertStateData( seq, err := GetNextStateBlockNID(s, ctx) id = types.StateBlockNID(seq) - data := StateBlockCosmos{ + data := stateBlockCosmos{ StateBlockNID: seq, StateBlockHash: nids.Hash(), EventNIDs: ids, } - var dbData = StateBlockCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = stateBlockCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), StateBlock: data, } @@ -235,14 +218,18 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( // if err != nil { // return nil, err // } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": stateBlockNIDs, } // rows, err := selectStmt.QueryContext(ctx, intfs...) - rows, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params) + var rows []stateBlockCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.bulkSelectStateBlockEntriesStmt, params, &rows) if err != nil { return nil, err diff --git a/roomserver/storage/cosmosdb/state_snapshot_table.go b/roomserver/storage/cosmosdb/state_snapshot_table.go index fb5801423..4c6a81d03 100644 --- a/roomserver/storage/cosmosdb/state_snapshot_table.go +++ b/roomserver/storage/cosmosdb/state_snapshot_table.go @@ -47,16 +47,16 @@ import ( // state_block_nids TEXT NOT NULL DEFAULT '[]' // ); -type StateSnapshotCosmos struct { +type stateSnapshotCosmos struct { StateSnapshotNID int64 `json:"state_snapshot_nid"` StateSnapshotHash []byte `json:"state_snapshot_hash"` RoomNID int64 `json:"room_nid"` StateBlockNIDs []int64 `json:"state_block_nids"` } -type StateSnapshotCosmosData struct { +type stateSnapshotCosmosData struct { cosmosdbapi.CosmosDocument - StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"` + StateSnapshot stateSnapshotCosmos `json:"mx_roomserver_state_snapshot"` } // const insertStateSQL = ` @@ -82,6 +82,14 @@ type stateSnapshotStatements struct { tableName string } +func (s *stateSnapshotStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *stateSnapshotStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 { result := []int64{} for i := 0; i < len(stateBlockNIDs); i++ { @@ -133,22 +141,19 @@ func (s *stateSnapshotStatements) InsertState( // return // } - data := StateSnapshotCosmos{ + // state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + docId := fmt.Sprintf("%d", stateSnapshotNIDSeq) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + + data := stateSnapshotCosmos{ RoomNID: int64(roomNID), StateSnapshotHash: stateBlockNIDs.Hash(), StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs), StateSnapshotNID: int64(stateSnapshotNIDSeq), } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - - // state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, - docId := fmt.Sprintf("%d", stateSnapshotNIDSeq) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - var dbData = StateSnapshotCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = stateSnapshotCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), StateSnapshot: data, } @@ -174,30 +179,25 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( // "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + // " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []StateSnapshotCosmosData params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": stateNIDs, } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rows []stateSnapshotCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + s.getPartitionKey(), s.bulkSelectStateBlockNIDsStmt, params, &rows) + if err != nil { return nil, err } results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 - for _, item := range response { + for _, item := range rows { result := &results[i] result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID) result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs) diff --git a/roomserver/storage/cosmosdb/transactions_table.go b/roomserver/storage/cosmosdb/transactions_table.go index 8e67b3679..016254512 100644 --- a/roomserver/storage/cosmosdb/transactions_table.go +++ b/roomserver/storage/cosmosdb/transactions_table.go @@ -35,16 +35,16 @@ import ( // ); // ` -type TransactionCosmos struct { +type transactionCosmos struct { TransactionID string `json:"transaction_id"` SessionID int64 `json:"session_id"` UserID string `json:"user_id"` EventID string `json:"event_id"` } -type TransactionCosmosData struct { +type transactionCosmosData struct { cosmosdbapi.CosmosDocument - Transaction TransactionCosmos `json:"mx_roomserver_transaction"` + Transaction transactionCosmos `json:"mx_roomserver_transaction"` } // const insertTransactionSQL = ` @@ -64,8 +64,16 @@ type transactionStatements struct { tableName string } -func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) { - response := TransactionCosmosData{} +func (s *transactionStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *transactionStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) { + response := transactionCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -103,22 +111,20 @@ func (s *transactionStatements) InsertTransaction( // INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) // VALUES ($1, $2, $3, $4) - data := TransactionCosmos{ + + // PRIMARY KEY (transaction_id, session_id, user_id) + docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + + data := transactionCosmos{ EventID: eventID, SessionID: sessionID, TransactionID: transactionID, UserID: userID, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - - // PRIMARY KEY (transaction_id, session_id, user_id) - docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - var dbData = TransactionCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = transactionCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Transaction: data, } @@ -143,13 +149,11 @@ func (s *transactionStatements) SelectTransactionEventID( // SELECT event_id FROM roomserver_transactions // WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // PRIMARY KEY (transaction_id, session_id, user_id) docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getTransaction(s, ctx, pk, cosmosDocId) + response, err := getTransaction(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return "", err diff --git a/signingkeyserver/storage/cosmosdb/server_key_table.go b/signingkeyserver/storage/cosmosdb/server_key_table.go index 22892ed1a..6a10a8c3b 100644 --- a/signingkeyserver/storage/cosmosdb/server_key_table.go +++ b/signingkeyserver/storage/cosmosdb/server_key_table.go @@ -51,7 +51,7 @@ import ( // CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); // ` -type ServerKeyCosmos struct { +type serverKeyCosmos struct { ServerName string `json:"server_name"` ServerKeyID string `json:"server_key_id"` ServerNameAndKeyID string `json:"server_name_and_key_id"` @@ -60,9 +60,9 @@ type ServerKeyCosmos struct { ServerKey string `json:"server_key"` } -type ServerKeyCosmosData struct { +type serverKeyCosmosData struct { cosmosdbapi.CosmosDocument - ServerKey ServerKeyCosmos `json:"mx_keydb_server_key"` + ServerKey serverKeyCosmos `json:"mx_keydb_server_key"` } // "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + @@ -87,8 +87,16 @@ type serverKeyStatements struct { tableName string } -func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId string) (*ServerKeyCosmosData, error) { - response := ServerKeyCosmosData{} +func (s *serverKeyStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *serverKeyStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId string) (*serverKeyCosmosData, error) { + response := serverKeyCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -104,27 +112,6 @@ func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId return &response, err } -func queryServerKey(s *serverKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ServerKeyCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ServerKeyCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func (s *serverKeyStatements) prepare(db *Database, writer sqlutil.Writer) (err error) { s.db = db s.writer = writer @@ -146,9 +133,8 @@ func (s *serverKeyStatements) bulkSelectServerKeys( // iKeyIDs[i] = v // } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": nameAndKeyIDs, } @@ -159,8 +145,12 @@ func (s *serverKeyStatements) bulkSelectServerKeys( // err := sqlutil.RunLimitedVariablesQuery( // ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, // func(rows *sql.Rows) error { - - rows, err := queryServerKey(s, ctx, bulkSelectServerKeysSQL, params) + var rows []serverKeyCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), bulkSelectServerKeysSQL, params, &rows) if err != nil { return nil, err @@ -213,20 +203,18 @@ func (s *serverKeyStatements) upsertServerKeys( // stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (server_name, server_key_id) docId := fmt.Sprintf("%s_%s", string(request.ServerName), string(request.KeyID)) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getServerKey(s, ctx, pk, cosmosDocId) + dbData, _ := getServerKey(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.ServerKey.ValidUntilTimestamp = int64(key.ValidUntilTS) dbData.ServerKey.ExpiredTimestamp = int64(key.ExpiredTS) dbData.ServerKey.ServerKey = key.Key.Encode() } else { - data := ServerKeyCosmos{ + data := serverKeyCosmos{ ServerName: string(request.ServerName), ServerKeyID: string(request.KeyID), ServerNameAndKeyID: nameAndKeyID(request), @@ -235,8 +223,8 @@ func (s *serverKeyStatements) upsertServerKeys( ServerKey: key.Key.Encode(), } - dbData = &ServerKeyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &serverKeyCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), ServerKey: data, } } diff --git a/syncapi/storage/cosmosdb/account_data_table.go b/syncapi/storage/cosmosdb/account_data_table.go index a401741c6..670b4f309 100644 --- a/syncapi/storage/cosmosdb/account_data_table.go +++ b/syncapi/storage/cosmosdb/account_data_table.go @@ -39,7 +39,7 @@ import ( // ); // ` -type AccountDataTypeCosmos struct { +type accountDataTypeCosmos struct { ID int64 `json:"id"` UserID string `json:"user_id"` RoomID string `json:"room_id"` @@ -50,9 +50,9 @@ type AccountDataTypeNumberCosmosData struct { Number int64 `json:"number"` } -type AccountDataTypeCosmosData struct { +type accountDataTypeCosmosData struct { cosmosdbapi.CosmosDocument - AccountDataType AccountDataTypeCosmos `json:"mx_syncapi_account_data_type"` + AccountDataType accountDataTypeCosmos `json:"mx_syncapi_account_data_type"` } // const insertAccountDataSQL = "" + @@ -83,8 +83,16 @@ type accountDataStatements struct { tableName string } -func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string, docId string) (*AccountDataTypeCosmosData, error) { - response := AccountDataTypeCosmosData{} +func (s *accountDataStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *accountDataStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string, docId string) (*accountDataTypeCosmosData, error) { + response := accountDataTypeCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -100,48 +108,6 @@ func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string return &response, err } -func queryAccountDataType(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []AccountDataTypeCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func queryAccountDataTypeNumber(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeNumberCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []AccountDataTypeNumberCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, cosmosdbutil.ErrNoRows - } - return response, nil -} - func NewCosmosDBAccountDataTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, @@ -168,26 +134,24 @@ func (s *accountDataStatements) InsertAccountData( return } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (user_id, room_id, type) docId := fmt.Sprintf("%s_%s_%s", userID, roomID, dataType) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getAccountDataType(s, ctx, pk, cosmosDocId) + dbData, _ := getAccountDataType(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() dbData.AccountDataType.ID = int64(pos) } else { - data := AccountDataTypeCosmos{ + data := accountDataTypeCosmos{ ID: int64(pos), UserID: userID, RoomID: roomID, DataType: dataType, } - dbData = &AccountDataTypeCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &accountDataTypeCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), AccountDataType: data, } } @@ -216,15 +180,18 @@ func (s *accountDataStatements) SelectAccountDataInRange( // " ORDER BY id ASC" // rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": r.Low(), "@x4": r.High(), } - - rows, err := queryAccountDataType(s, ctx, s.selectAccountDataInRangeStmt, params) + var rows []accountDataTypeCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAccountDataInRangeStmt, params, &rows) if err != nil { return @@ -276,12 +243,16 @@ func (s *accountDataStatements) SelectMaxAccountDataID( var nullableID sql.NullInt64 // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - rows, err := queryAccountDataTypeNumber(s, ctx, s.selectMaxAccountDataIDStmt, params) + var rows []AccountDataTypeNumberCosmosData + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxAccountDataIDStmt, params, &rows) if err != cosmosdbutil.ErrNoRows && len(rows) == 1 { nullableID.Int64 = rows[0].Number diff --git a/syncapi/storage/cosmosdb/backwards_extremities_table.go b/syncapi/storage/cosmosdb/backwards_extremities_table.go index 944a043d7..113e7e573 100644 --- a/syncapi/storage/cosmosdb/backwards_extremities_table.go +++ b/syncapi/storage/cosmosdb/backwards_extremities_table.go @@ -37,15 +37,15 @@ import ( // ); // ` -type BackwardExtremityCosmos struct { +type backwardExtremityCosmos struct { RoomID string `json:"room_id"` EventID string `json:"event_id"` PrevEventID string `json:"prev_event_id"` } -type BackwardExtremityCosmosData struct { +type backwardExtremityCosmosData struct { cosmosdbapi.CosmosDocument - BackwardExtremity BackwardExtremityCosmos `json:"mx_syncapi_backward_extremity"` + BackwardExtremity backwardExtremityCosmos `json:"mx_syncapi_backward_extremity"` } // const insertBackwardExtremitySQL = "" + @@ -78,8 +78,16 @@ type backwardExtremitiesStatements struct { tableName string } -func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*BackwardExtremityCosmosData, error) { - response := BackwardExtremityCosmosData{} +func (s *backwardExtremitiesStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *backwardExtremitiesStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) { + response := backwardExtremityCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -95,28 +103,7 @@ func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, return &response, err } -func queryBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BackwardExtremityCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []BackwardExtremityCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData BackwardExtremityCosmosData) error { +func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData backwardExtremityCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -152,24 +139,22 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( // _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // PRIMARY KEY(room_id, event_id, prev_event_id) docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getBackwardExtremity(s, ctx, pk, cosmosDocId) + dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { dbData.SetUpdateTime() } else { - data := BackwardExtremityCosmos{ + data := backwardExtremityCosmos{ EventID: eventID, PrevEventID: prevEventID, RoomID: roomID, } - dbData = &BackwardExtremityCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &backwardExtremityCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), BackwardExtremity: data, } } @@ -191,13 +176,16 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( // "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" // rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } - - rows, err := queryBackwardExtremity(s, ctx, s.selectBackwardExtremitiesForRoomStmt, params) + var rows []backwardExtremityCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectBackwardExtremitiesForRoomStmt, params, &rows) if err != nil { return @@ -223,14 +211,18 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( // _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": knownEventID, } + var rows []backwardExtremityCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteBackwardExtremityStmt, params, &rows) - rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremityStmt, params) if err != nil { return } @@ -249,13 +241,18 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( // _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } - rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremitiesForRoomStmt, params) + var rows []backwardExtremityCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteBackwardExtremitiesForRoomStmt, params, &rows) + if err != nil { return } diff --git a/syncapi/storage/cosmosdb/current_room_state_table.go b/syncapi/storage/cosmosdb/current_room_state_table.go index eea182cac..db14c8edb 100644 --- a/syncapi/storage/cosmosdb/current_room_state_table.go +++ b/syncapi/storage/cosmosdb/current_room_state_table.go @@ -52,7 +52,7 @@ import ( // CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); // ` -type CurrentRoomStateCosmos struct { +type currentRoomStateCosmos struct { RoomID string `json:"room_id"` EventID string `json:"event_id"` Type string `json:"type"` @@ -64,9 +64,9 @@ type CurrentRoomStateCosmos struct { AddedAt int64 `json:"added_at"` } -type CurrentRoomStateCosmosData struct { +type currentRoomStateCosmosData struct { cosmosdbapi.CosmosDocument - CurrentRoomState CurrentRoomStateCosmos `json:"mx_syncapi_current_room_state"` + CurrentRoomState currentRoomStateCosmos `json:"mx_syncapi_current_room_state"` } // const upsertRoomStateSQL = "" + @@ -132,50 +132,16 @@ type currentRoomStateStatements struct { jsonPropertyName string } -func queryCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []CurrentRoomStateCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *currentRoomStateStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryCurrentRoomStateDistinct(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmos, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []CurrentRoomStateCosmos - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *currentRoomStateStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*CurrentRoomStateCosmosData, error) { - response := CurrentRoomStateCosmosData{} +func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*currentRoomStateCosmosData, error) { + response := currentRoomStateCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -191,7 +157,7 @@ func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, doc return &response, err } -func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData CurrentRoomStateCosmosData) error { +func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData currentRoomStateCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -228,12 +194,15 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" // rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - - rows, err := queryCurrentRoomState(s, ctx, s.selectJoinedUsersStmt, params) + var rows []currentRoomStateCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectJoinedUsersStmt, params, &rows) if err != nil { return nil, err @@ -264,14 +233,18 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( // stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) // rows, err := stmt.QueryContext(ctx, userID, membership) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": membership, } - rows, err := queryCurrentRoomStateDistinct(s, ctx, s.selectRoomIDsWithMembershipStmt, params) + var rows []currentRoomStateCosmos + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectRoomIDsWithMembershipStmt, params, &rows) if err != nil { return nil, err @@ -296,9 +269,8 @@ func (s *currentRoomStateStatements) SelectCurrentState( // "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" // // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": stateFilter.Limit, } @@ -309,7 +281,12 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter.Types, stateFilter.NotTypes, excludeEventIDs, stateFilter.Limit, FilterOrderNone, ) - rows, err := queryCurrentRoomState(s, ctx, stmt, params) + var rows []currentRoomStateCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), stmt, params, &rows) if err != nil { return nil, err @@ -325,13 +302,16 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( // "DELETE FROM syncapi_current_room_state WHERE event_id = $1" // stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventID, } - - rows, err := queryCurrentRoomState(s, ctx, s.deleteRoomStateByEventIDStmt, params) + var rows []currentRoomStateCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteRoomStateByEventIDStmt, params, &rows) for _, item := range rows { err = deleteCurrentRoomState(s, ctx, item) @@ -348,13 +328,16 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom( // "DELETE FROM syncapi_current_room_state WHERE event_id = $1" // stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } - - rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params) + var rows []currentRoomStateCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.DeleteRoomStateForRoomStmt, params, &rows) for _, item := range rows { err = deleteCurrentRoomState(s, ctx, item) @@ -407,18 +390,16 @@ func (s *currentRoomStateStatements) UpsertRoomState( // addedAt, // ) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // " ON CONFLICT (room_id, type, state_key)" + docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), event.Type(), *event.StateKey()) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) membershipData := "" if membership != nil { membershipData = *membership } - dbData, _ := getEvent(s, ctx, pk, cosmosDocId) + dbData, _ := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { // " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" dbData.SetUpdateTime() @@ -429,7 +410,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( dbData.CurrentRoomState.Membership = membershipData dbData.CurrentRoomState.AddedAt = int64(addedAt) } else { - data := CurrentRoomStateCosmos{ + data := currentRoomStateCosmos{ RoomID: event.RoomID(), EventID: event.EventID(), Type: event.Type(), @@ -441,8 +422,8 @@ func (s *currentRoomStateStatements) UpsertRoomState( AddedAt: int64(addedAt), } - dbData = &CurrentRoomStateCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = ¤tRoomStateCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), CurrentRoomState: data, } } @@ -480,13 +461,17 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( // query := strings.Replace(selectEventsWithEventIDsSQL, "@x2", sql.QueryVariadic(n), 1) // rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventIDs, } - rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params) + var rows []currentRoomStateCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.DeleteRoomStateForRoomStmt, params, &rows) if err != nil { return nil, err @@ -502,7 +487,7 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( } // Copied from output_room_events_table -func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData) ([]types.StreamEvent, error) { +func rowsToStreamEventsFromCurrentRoomState(rows *[]currentRoomStateCosmosData) ([]types.StreamEvent, error) { var result []types.StreamEvent for _, item := range *rows { var ( @@ -546,7 +531,7 @@ func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData) return result, nil } -func rowsToEvents(rows *[]CurrentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) { +func rowsToEvents(rows *[]currentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) { result := []*gomatrixserverlib.HeaderedEvent{} for _, item := range *rows { var eventID string @@ -570,12 +555,10 @@ func (s *currentRoomStateStatements) SelectStateEvent( // stmt := s.selectStateEventStmt var res []byte - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // " ON CONFLICT (room_id, type, state_key)" + docId := fmt.Sprintf("%s_%s_%s", roomID, evType, stateKey) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - var response, err = getEvent(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, err = getEvent(s, ctx, s.getPartitionKey(), cosmosDocId) // err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == cosmosdbutil.ErrNoRows { diff --git a/syncapi/storage/cosmosdb/filter_table.go b/syncapi/storage/cosmosdb/filter_table.go index a9fec490f..c6bb154dc 100644 --- a/syncapi/storage/cosmosdb/filter_table.go +++ b/syncapi/storage/cosmosdb/filter_table.go @@ -41,15 +41,15 @@ import ( // CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart); // ` -type FilterCosmos struct { +type filterCosmos struct { ID int64 `json:"id"` Filter []byte `json:"filter"` Localpart string `json:"localpart"` } -type FilterCosmosData struct { +type filterCosmosData struct { cosmosdbapi.CosmosDocument - Filter FilterCosmos `json:"mx_syncapi_filter"` + Filter filterCosmos `json:"mx_syncapi_filter"` } // const selectFilterSQL = "" + @@ -72,34 +72,16 @@ type filterStatements struct { tableName string } -func queryFilter(s *filterStatements, ctx context.Context, qry string, params map[string]interface{}) ([]FilterCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []FilterCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - - if len(response) == 0 { - return nil, cosmosdbutil.ErrNoRows - } - - return response, nil +func (s *filterStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*FilterCosmosData, error) { - response := FilterCosmosData{} +func (s *filterStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*filterCosmosData, error) { + response := filterCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -134,12 +116,10 @@ func (s *filterStatements) SelectFilter( var filterData []byte // err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE (id, localpart) docId := fmt.Sprintf("%s_%s", localpart, filterID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response, err = getFilter(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, err = getFilter(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return nil, err @@ -187,21 +167,25 @@ func (s *filterStatements) InsertFilter( // TODO: See if we can avoid the search by Content []byte // "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, "@x3": filterJSON, } - response, err := queryFilter(s, ctx, s.selectFilterIDByContentStmt, params) + var rows []filterCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectFilterIDByContentStmt, params, &rows) if err != nil && err != cosmosdbutil.ErrNoRows { return "", err } - if response != nil { - existingFilterID = fmt.Sprintf("%d", response[0].Filter.ID) + if len(rows) > 0 { + existingFilterID = fmt.Sprintf("%d", rows[0].Filter.ID) } // If it does, return the existing ID if existingFilterID != "" { @@ -217,7 +201,7 @@ func (s *filterStatements) InsertFilter( return "", seqErr } - data := FilterCosmos{ + data := filterCosmos{ ID: seqID, Localpart: localpart, Filter: filterJSON, @@ -225,11 +209,10 @@ func (s *filterStatements) InsertFilter( // UNIQUE (id, localpart) docId := fmt.Sprintf("%s_%d", localpart, seqID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var dbData = FilterCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = filterCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Filter: data, } diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go index 2dc4188ad..4af177156 100644 --- a/syncapi/storage/cosmosdb/invites_table.go +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -43,7 +43,7 @@ import ( // CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id); // ` -type InviteEventCosmos struct { +type inviteEventCosmos struct { ID int64 `json:"id"` EventID string `json:"event_id"` RoomID string `json:"room_id"` @@ -52,13 +52,13 @@ type InviteEventCosmos struct { Deleted bool `json:"deleted"` } -type InviteEventCosmosMaxNumber struct { +type inviteEventCosmosMaxNumber struct { Max int64 `json:"number"` } -type InviteEventCosmosData struct { +type inviteEventCosmosData struct { cosmosdbapi.CosmosDocument - InviteEvent InviteEventCosmos `json:"mx_syncapi_invite_event"` + InviteEvent inviteEventCosmos `json:"mx_syncapi_invite_event"` } // const insertInviteEventSQL = "" + @@ -95,51 +95,16 @@ type inviteEventsStatements struct { tableName string } -func queryInviteEvent(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []InviteEventCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *inviteEventsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryInviteEventMaxNumber(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosMaxNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []InviteEventCosmosMaxNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, nil - } - - return response, nil +func (s *inviteEventsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*InviteEventCosmosData, error) { - response := InviteEventCosmosData{} +func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*inviteEventCosmosData, error) { + response := inviteEventCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -155,7 +120,7 @@ func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, d return &response, err } -func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite InviteEventCosmosData) (*InviteEventCosmosData, error) { +func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite inviteEventCosmosData) (*inviteEventCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -207,7 +172,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( // *inviteEvent.StateKey(), // headeredJSON, // ) - data := InviteEventCosmos{ + data := inviteEventCosmos{ ID: int64(streamPos), RoomID: inviteEvent.RoomID(), EventID: inviteEvent.EventID(), @@ -215,14 +180,12 @@ func (s *inviteEventsStatements) InsertInviteEvent( HeaderedEventJSON: headeredJSON, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // id INTEGER PRIMARY KEY, docId := fmt.Sprintf("%d", streamPos) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var dbData = InviteEventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = inviteEventCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), InviteEvent: data, } @@ -249,14 +212,18 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) // _, err = stmt.ExecContext(ctx, streamPos, inviteEventID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": inviteEventID, } - response, err := queryInviteEvent(s, ctx, s.deleteInviteEventStmt, params) + var rows []inviteEventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteInviteEventStmt, params, &rows) - for _, item := range response { + for _, item := range rows { item.InviteEvent.Deleted = true item.InviteEvent.ID = int64(streamPos) setInviteEvent(s, ctx, item) @@ -276,14 +243,18 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( // stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) // rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": targetUserID, "@x3": r.Low(), "@x4": r.High(), } - rows, err := queryInviteEvent(s, ctx, s.selectInviteEventsInRangeStmt, params) + var rows []inviteEventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectInviteEventsInRangeStmt, params, &rows) if err != nil { return nil, nil, err @@ -333,14 +304,18 @@ func (s *inviteEventsStatements) SelectMaxInviteID( // stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - response, err := queryInviteEventMaxNumber(s, ctx, s.selectMaxInviteIDStmt, params) + var rows []inviteEventCosmosMaxNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxInviteIDStmt, params, &rows) - if response != nil { - nullableID.Int64 = response[0].Max + if len(rows) > 0 { + nullableID.Int64 = rows[0].Max } if nullableID.Valid { diff --git a/syncapi/storage/cosmosdb/memberships_table.go b/syncapi/storage/cosmosdb/memberships_table.go index 105766da5..7c3560757 100644 --- a/syncapi/storage/cosmosdb/memberships_table.go +++ b/syncapi/storage/cosmosdb/memberships_table.go @@ -51,7 +51,7 @@ import ( // ); // ` -type MembershipCosmos struct { +type membershipCosmos struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` Membership string `json:"membership"` @@ -60,9 +60,9 @@ type MembershipCosmos struct { TopologicalPos int64 `json:"topological_pos"` } -type MembershipCosmosData struct { +type membershipCosmosData struct { cosmosdbapi.CosmosDocument - Membership MembershipCosmos `json:"mx_syncapi_membership"` + Membership membershipCosmos `json:"mx_syncapi_membership"` } // const upsertMembershipSQL = "" + @@ -88,8 +88,16 @@ type membershipsStatements struct { tableName string } -func getMembership(s *membershipsStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) { - response := MembershipCosmosData{} +func (s *membershipsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *membershipsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getMembership(s *membershipsStatements, ctx context.Context, pk string, docId string) (*membershipCosmosData, error) { + response := membershipCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -105,27 +113,6 @@ func getMembership(s *membershipsStatements, ctx context.Context, pk string, doc return &response, err } -func queryMembership(s *membershipsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []MembershipCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func NewCosmosDBMembershipsTable(db *SyncServerDatasource) (tables.Memberships, error) { s := &membershipsStatements{ db: db, @@ -158,13 +145,11 @@ func (s *membershipsStatements) UpsertMembership( // topologicalPos, // ) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // UNIQUE (room_id, user_id, membership) docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), *event.StateKey(), membership) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getMembership(s, ctx, pk, cosmosDocId) + dbData, _ := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { // " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" dbData.SetUpdateTime() @@ -172,7 +157,7 @@ func (s *membershipsStatements) UpsertMembership( dbData.Membership.StreamPos = int64(streamPos) dbData.Membership.TopologicalPos = int64(topologicalPos) } else { - data := MembershipCosmos{ + data := membershipCosmos{ RoomID: event.RoomID(), UserID: *event.StateKey(), Membership: membership, @@ -181,8 +166,8 @@ func (s *membershipsStatements) UpsertMembership( TopologicalPos: int64(topologicalPos), } - dbData = &MembershipCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &membershipCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Membership: data, } } @@ -209,15 +194,19 @@ func (s *membershipsStatements) SelectMembership( // " LIMIT 1" // err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": userID, "@x4": memberships, } // orig := strings.Replace(selectMembershipSQL, "@x4", cosmosdbutil.QueryVariadicOffset(len(memberships), 2), 1) - rows, err := queryMembership(s, ctx, selectMembershipSQL, params) + var rows []membershipCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), selectMembershipSQL, params, &rows) if err != nil || len(rows) == 0 { return "", 0, 0, err diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go index 52760c322..e4d0d0f34 100644 --- a/syncapi/storage/cosmosdb/output_room_events_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -22,8 +22,6 @@ import ( "fmt" "sort" - "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/roomserver/api" @@ -52,7 +50,7 @@ import ( // ); // ` -type OutputRoomEventCosmos struct { +type outputRoomEventCosmos struct { ID int64 `json:"id"` EventID string `json:"event_id"` RoomID string `json:"room_id"` @@ -67,13 +65,13 @@ type OutputRoomEventCosmos struct { ExcludeFromSync bool `json:"exclude_from_sync"` } -type OutputRoomEventCosmosMaxNumber struct { +type outputRoomEventCosmosMaxNumber struct { Max int64 `json:"number"` } -type OutputRoomEventCosmosData struct { +type outputRoomEventCosmosData struct { cosmosdbapi.CosmosDocument - OutputRoomEvent OutputRoomEventCosmos `json:"mx_syncapi_output_room_event"` + OutputRoomEvent outputRoomEventCosmos `json:"mx_syncapi_output_room_event"` } // const insertEventSQL = "" + @@ -152,49 +150,15 @@ type outputRoomEventsStatements struct { jsonPropertyName string } -func queryOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []OutputRoomEventCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *outputRoomEventsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryOutputRoomEventNumber(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosMaxNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []OutputRoomEventCosmosMaxNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, cosmosdbutil.ErrNoRows - } - return response, nil +func (s *outputRoomEventsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent OutputRoomEventCosmosData) (*OutputRoomEventCosmosData, error) { +func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent outputRoomEventCosmosData) (*outputRoomEventCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outputRoomEvent.Pk, outputRoomEvent.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -206,7 +170,7 @@ func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outp return &outputRoomEvent, ex } -func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData OutputRoomEventCosmosData) error { +func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData outputRoomEventCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -243,14 +207,19 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": event.EventID(), } // _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params) + var rows []outputRoomEventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteEventsForRoomStmt, params, &rows) + if err != nil { return err } @@ -261,7 +230,6 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event } return err - return err } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -277,9 +245,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange( // " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" // // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": r.Low(), "@x3": r.High(), "@x4": stateFilter.Limit, @@ -292,7 +259,13 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ) // rows, err := stmt.QueryContext(ctx, params...) - rows, err := queryOutputRoomEvent(s, ctx, query, params) + var rows []outputRoomEventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), query, params, &rows) + if err != nil { return nil, nil, err } @@ -374,13 +347,17 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ) (id int64, err error) { var nullableID sql.NullInt64 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + var rows []outputRoomEventCosmosMaxNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxEventIDStmt, params, &rows) - rows, err := queryOutputRoomEventNumber(s, ctx, s.selectMaxEventIDStmt, params) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) if rows != nil { @@ -464,7 +441,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // excludeFromSync, // ) - data := OutputRoomEventCosmos{ + data := outputRoomEventCosmos{ ID: int64(streamPos), RoomID: event.RoomID(), EventID: event.EventID(), @@ -482,14 +459,12 @@ func (s *outputRoomEventsStatements) InsertEvent( data.TransactionID = *txnID } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // id INTEGER PRIMARY KEY, docId := fmt.Sprintf("%d", streamPos) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var dbData = OutputRoomEventCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = outputRoomEventCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), OutputRoomEvent: data, } @@ -521,9 +496,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( query = selectRecentEventsSQL } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": r.Low(), "@x4": r.High(), @@ -538,7 +512,12 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( ) // rows, err := stmt.QueryContext(ctx, params...) - rows, err := queryOutputRoomEvent(s, ctx, query, params) + var rows []outputRoomEventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), query, params, &rows) if err != nil { return nil, false, err @@ -577,9 +556,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // " WHERE room_id = $1 AND id > $2 AND id <= $3" // // WHEN, ORDER BY (and not LIMIT) are appended by prepareWithFilters - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": r.Low(), "@x4": r.High(), @@ -593,7 +571,13 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ) // rows, err := stmt.QueryContext(ctx, params...) - rows, err := queryOutputRoomEvent(s, ctx, stmt, params) + var rows []outputRoomEventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), stmt, params, &rows) + if err != nil { return nil, err } @@ -622,14 +606,19 @@ func (s *outputRoomEventsStatements) SelectEvents( // stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) for _, eventID := range eventIDs { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventID, } // rows, err := stmt.QueryContext(ctx, eventID) - rows, err := queryOutputRoomEvent(s, ctx, s.selectEventsStmt, params) + var rows []outputRoomEventCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectEventsStmt, params, &rows) + if err != nil { return nil, err } @@ -645,14 +634,19 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom( ) (err error) { // "DELETE FROM syncapi_output_room_events WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } // _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) - rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params) + var rows []outputRoomEventCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteEventsForRoomStmt, params, &rows) + if err != nil { return err } @@ -664,7 +658,7 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom( return err } -func rowsToStreamEvents(rows *[]OutputRoomEventCosmosData) ([]types.StreamEvent, error) { +func rowsToStreamEvents(rows *[]outputRoomEventCosmosData) ([]types.StreamEvent, error) { var result []types.StreamEvent for _, item := range *rows { var ( diff --git a/syncapi/storage/cosmosdb/output_room_events_topology_table.go b/syncapi/storage/cosmosdb/output_room_events_topology_table.go index 8a5c7ff83..db7d5cfd2 100644 --- a/syncapi/storage/cosmosdb/output_room_events_topology_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_topology_table.go @@ -39,16 +39,16 @@ import ( // -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); // ` -type OutputRoomEventTopologyCosmos struct { +type outputRoomEventTopologyCosmos struct { EventID string `json:"event_id"` TopologicalPosition int64 `json:"topological_position"` StreamPosition int64 `json:"stream_position"` RoomID string `json:"room_id"` } -type OutputRoomEventTopologyCosmosData struct { +type outputRoomEventTopologyCosmosData struct { cosmosdbapi.CosmosDocument - OutputRoomEventTopology OutputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"` + OutputRoomEventTopology outputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"` } // const insertEventInTopologySQL = "" + @@ -126,29 +126,16 @@ type outputRoomEventsTopologyStatements struct { tableName string } -func queryOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventTopologyCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []OutputRoomEventTopologyCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *outputRoomEventsTopologyStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, pk string, docId string) (*OutputRoomEventTopologyCosmosData, error) { - response := OutputRoomEventTopologyCosmosData{} +func (s *outputRoomEventsTopologyStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, pk string, docId string) (*outputRoomEventTopologyCosmosData, error) { + response := outputRoomEventTopologyCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -164,7 +151,7 @@ func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx conte return &response, err } -func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData OutputRoomEventTopologyCosmosData) error { +func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData outputRoomEventTopologyCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, @@ -203,26 +190,24 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( // " VALUES ($1, $2, $3, $4)" + // " ON CONFLICT DO NOTHING" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE(topological_position, room_id, stream_position) docId := fmt.Sprintf("%d_%s_%d", event.Depth(), event.RoomID(), pos) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) var err error - dbData, _ := getOutputRoomEventTopology(s, ctx, pk, cosmosDocId) + dbData, _ := getOutputRoomEventTopology(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { // " ON CONFLICT DO NOTHING" } else { - data := OutputRoomEventTopologyCosmos{ + data := outputRoomEventTopologyCosmos{ EventID: event.EventID(), TopologicalPosition: event.Depth(), RoomID: event.RoomID(), StreamPosition: int64(pos), } - dbData = &OutputRoomEventTopologyCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &outputRoomEventTopologyCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), OutputRoomEventTopology: data, } // _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( @@ -265,9 +250,8 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( } // Query the event IDs. - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": minDepth, "@x4": maxDepth, @@ -275,8 +259,12 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( "@x6": maxStreamPos, "@x7": limit, } - - rows, err := queryOutputRoomEventTopology(s, ctx, stmt, params) + var rows []outputRoomEventTopologyCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), stmt, params, &rows) // rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) if err == sql.ErrNoRows { @@ -308,13 +296,17 @@ func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( // "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + // " WHERE event_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": eventID, } + var rows []outputRoomEventTopologyCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectPositionInTopologyStmt, params, &rows) - rows, err := queryOutputRoomEventTopology(s, ctx, s.selectPositionInTopologyStmt, params) // stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) if err != nil { @@ -342,13 +334,17 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( // ") ORDER BY stream_position DESC LIMIT 1" // stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } + var rows []outputRoomEventTopologyCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectMaxPositionInTopologyStmt, params, &rows) - rows, err := queryOutputRoomEventTopology(s, ctx, s.selectMaxPositionInTopologyStmt, params) // err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) if err != nil { @@ -369,13 +365,17 @@ func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( ) (err error) { // "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, } + var rows []outputRoomEventTopologyCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteTopologyForRoomStmt, params, &rows) - rows, err := queryOutputRoomEventTopology(s, ctx, s.deleteTopologyForRoomStmt, params) // _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) if err != nil { diff --git a/syncapi/storage/cosmosdb/peeks_table.go b/syncapi/storage/cosmosdb/peeks_table.go index d08d94cae..d13927f98 100644 --- a/syncapi/storage/cosmosdb/peeks_table.go +++ b/syncapi/storage/cosmosdb/peeks_table.go @@ -42,7 +42,7 @@ import ( // CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id); // ` -type PeekCosmos struct { +type peekCosmos struct { ID int64 `json:"id"` RoomID string `json:"room_id"` UserID string `json:"user_id"` @@ -52,13 +52,13 @@ type PeekCosmos struct { // creation_ts int64 `json:"creation_ts"` } -type PeekCosmosMaxNumber struct { +type peekCosmosMaxNumber struct { Max int64 `json:"number"` } -type PeekCosmosData struct { +type peekCosmosData struct { cosmosdbapi.CosmosDocument - Peek PeekCosmos `json:"mx_syncapi_peek"` + Peek peekCosmos `json:"mx_syncapi_peek"` } // const insertPeekSQL = "" + @@ -115,50 +115,16 @@ type peekStatements struct { tableName string } -func queryPeek(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []PeekCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *peekStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryPeekMaxNumber(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosMaxNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []PeekCosmosMaxNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, nil - } - return response, nil +func (s *peekStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (*PeekCosmosData, error) { - response := PeekCosmosData{} +func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (*peekCosmosData, error) { + response := peekCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -174,7 +140,7 @@ func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (* return &response, err } -func setPeek(s *peekStatements, ctx context.Context, peek PeekCosmosData) (*PeekCosmosData, error) { +func setPeek(s *peekStatements, ctx context.Context, peek peekCosmosData) (*peekCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(peek.Pk, peek.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -213,28 +179,26 @@ func (s *peekStatements) InsertPeek( // " (id, room_id, user_id, device_id, creation_ts, deleted)" + // " VALUES ($1, $2, $3, $4, $5, false)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // UNIQUE(room_id, user_id, device_id) - docId := fmt.Sprintf("%d_%s_%d", roomID, userID, deviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + docId := fmt.Sprintf("%s_%s_%s", roomID, userID, deviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getPeek(s, ctx, pk, cosmosDocId) + dbData, _ := getPeek(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { // " (id, room_id, user_id, device_id, creation_ts, deleted)" + // " VALUES ($1, $2, $3, $4, $5, false)" dbData.SetUpdateTime() dbData.Peek.Deleted = false } else { - data := PeekCosmos{ + data := peekCosmos{ ID: int64(streamPos), RoomID: roomID, UserID: userID, DeviceID: deviceID, } - dbData = &PeekCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &peekCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Peek: data, } } @@ -257,15 +221,20 @@ func (s *peekStatements) DeletePeek( // "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": userID, "@x4": deviceID, } - rows, err := queryPeek(s, ctx, s.deletePeekStmt, params) + var rows []peekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deletePeekStmt, params, &rows) + // _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID) numAffected := len(rows) @@ -295,14 +264,19 @@ func (s *peekStatements) DeletePeeks( ) (types.StreamPosition, error) { // "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": roomID, "@x3": userID, } - rows, err := queryPeek(s, ctx, s.deletePeekStmt, params) + var rows []peekCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deletePeekStmt, params, &rows) + // result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) if err != nil { return 0, err @@ -334,16 +308,20 @@ func (s *peekStatements) SelectPeeksInRange( ) (peeks []types.Peek, err error) { // "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, "@x4": r.Low(), "@x5": r.High(), } + var rows []peekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectPeeksInRangeStmt, params, &rows) - rows, err := queryPeek(s, ctx, s.selectPeeksInRangeStmt, params) // rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) if err != nil { return @@ -371,12 +349,17 @@ func (s *peekStatements) SelectPeekingDevices( // "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - rows, err := queryPeek(s, ctx, s.selectPeekingDevicesStmt, params) + var rows []peekCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectPeekingDevicesStmt, params, &rows) + // rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) if err != nil { return nil, err @@ -405,12 +388,16 @@ func (s *peekStatements) SelectMaxPeekID( // stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) var nullableID sql.NullInt64 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } + var rows []peekCosmosMaxNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxPeekIDStmt, params, &rows) - rows, err := queryPeekMaxNumber(s, ctx, s.selectMaxPeekIDStmt, params) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) if rows != nil { diff --git a/syncapi/storage/cosmosdb/receipt_table.go b/syncapi/storage/cosmosdb/receipt_table.go index 1b2a03b83..73e792c4b 100644 --- a/syncapi/storage/cosmosdb/receipt_table.go +++ b/syncapi/storage/cosmosdb/receipt_table.go @@ -41,7 +41,7 @@ import ( // CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); // ` -type ReceiptCosmos struct { +type receiptCosmos struct { ID int64 `json:"id"` RoomID string `json:"room_id"` ReceiptType string `json:"receipt_type"` @@ -50,13 +50,13 @@ type ReceiptCosmos struct { ReceiptTS int64 `json:"receipt_ts"` } -type ReceiptCosmosMaxNumber struct { +type receiptCosmosMaxNumber struct { Max int64 `json:"number"` } -type ReceiptCosmosData struct { +type receiptCosmosData struct { cosmosdbapi.CosmosDocument - Receipt ReceiptCosmos `json:"mx_syncapi_receipt"` + Receipt receiptCosmos `json:"mx_syncapi_receipt"` } // const upsertReceipt = "" + @@ -87,46 +87,12 @@ type receiptStatements struct { tableName string } -func queryReceipt(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ReceiptCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *receiptStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryReceiptNumber(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosMaxNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ReceiptCosmosMaxNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, nil - } - return response, nil +func (s *receiptStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } func NewCosmosDBReceiptsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Receipts, error) { @@ -152,7 +118,11 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room // " ON CONFLICT (room_id, receipt_type, user_id)" + // " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" - data := ReceiptCosmos{ + // CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) + docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId) + cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, r.getCollectionName(), docId) + + data := receiptCosmos{ ID: int64(pos), RoomID: roomId, ReceiptType: receiptType, @@ -161,14 +131,8 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room ReceiptTS: int64(timestamp), } - var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName) - var pk = cosmosdbapi.GetPartitionKey(r.db.cosmosConfig.ContainerName, dbCollectionName) - // CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) - docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId) - cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, dbCollectionName, docId) - - var dbData = ReceiptCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, r.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = receiptCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(r.getCollectionName(), r.db.cosmosConfig.TenantName, r.getPartitionKey(), cosmosDocId), Receipt: data, } @@ -197,14 +161,18 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs // for k, v := range roomIDs { // params[k+1] = v - var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": r.getCollectionName(), "@x2": streamPos, "@x3": roomIDs, } + var rows []receiptCosmosData + err := cosmosdbapi.PerformQuery(ctx, + r.db.connection, + r.db.cosmosConfig.DatabaseName, + r.db.cosmosConfig.ContainerName, + r.getPartitionKey(), selectRoomReceipts, params, &rows) - rows, err := queryReceipt(r, ctx, selectRoomReceipts, params) // rows, err := r.db.QueryContext(ctx, selectSQL, params...) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) @@ -239,12 +207,16 @@ func (s *receiptStatements) SelectMaxReceiptID( // "SELECT MAX(id) FROM syncapi_receipts" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } + var rows []receiptCosmosMaxNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxReceiptID, params, &rows) - rows, err := queryReceiptNumber(s, ctx, s.selectMaxReceiptID, params) // stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) if rows != nil { diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go index 7db89115e..719d7801a 100644 --- a/syncapi/storage/cosmosdb/send_to_device_table.go +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -94,46 +94,12 @@ type sendToDeviceStatements struct { tableName string } -func querySendToDevice(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []SendToDeviceCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *sendToDeviceStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosMaxNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []SendToDeviceCosmosMaxNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, nil - } - return response, nil +func (s *sendToDeviceStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } func deleteSendToDevice(s *sendToDeviceStatements, ctx context.Context, dbData SendToDeviceCosmosData) error { @@ -180,6 +146,10 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( // INSERT INTO syncapi_send_to_device (user_id, device_id, content) // VALUES ($1, $2, $3) + // NO CONSTRAINT + docId := fmt.Sprintf("%d", pos) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + data := SendToDeviceCosmos{ ID: int64(pos), UserID: userID, @@ -187,14 +157,8 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage( Content: content, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - // NO CONSTRAINT - docId := fmt.Sprintf("%d", pos) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - var dbData = SendToDeviceCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), SendToDevice: data, } @@ -217,16 +181,21 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( // WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 // ORDER BY id DESC - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, "@x4": from, "@x5": to, } - rows, err := querySendToDevice(s, ctx, s.selectSendToDeviceMessagesStmt, params) + var rows []SendToDeviceCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectSendToDeviceMessagesStmt, params, &rows) + if err != nil { return } @@ -268,16 +237,21 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( // DELETE FROM syncapi_send_to_device // WHERE user_id = $1 AND device_id = $2 AND id < $3 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": deviceID, "@x4": pos, } // _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos) - rows, err := querySendToDevice(s, ctx, s.deleteSendToDeviceMessagesStmt, params) + var rows []SendToDeviceCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.deleteSendToDeviceMessagesStmt, params, &rows) + if err != nil { return err } @@ -297,12 +271,16 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID( var nullableID sql.NullInt64 // "SELECT MAX(id) FROM syncapi_send_to_device" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } + var rows []SendToDeviceCosmosMaxNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectMaxSendToDeviceIDStmt, params, &rows) - rows, err := querySendToDeviceNumber(s, ctx, s.selectMaxSendToDeviceIDStmt, params) // stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) // err = stmt.QueryRowContext(ctx).Scan(&nullableID) diff --git a/userapi/storage/accounts/cosmosdb/account_data_table.go b/userapi/storage/accounts/cosmosdb/account_data_table.go index 40750dea0..9b76b1fac 100644 --- a/userapi/storage/accounts/cosmosdb/account_data_table.go +++ b/userapi/storage/accounts/cosmosdb/account_data_table.go @@ -38,18 +38,18 @@ import ( // ); // ` -type AccountDataCosmosData struct { - cosmosdbapi.CosmosDocument - AccountData AccountDataCosmos `json:"mx_userapi_accountdata"` -} - -type AccountDataCosmos struct { +type accountDataCosmos struct { LocalPart string `json:"local_part"` RoomId string `json:"room_id"` Type string `json:"type"` Content []byte `json:"content"` } +type accountDataCosmosData struct { + cosmosdbapi.CosmosDocument + AccountData accountDataCosmos `json:"mx_userapi_accountdata"` +} + type accountDataStatements struct { db *Database // insertAccountDataStmt *sql.Stmt @@ -58,6 +58,14 @@ type accountDataStatements struct { tableName string } +func (s *accountDataStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *accountDataStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func (s *accountDataStatements) prepare(db *Database) (err error) { s.db = db s.selectAccountDataStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2" @@ -66,8 +74,8 @@ func (s *accountDataStatements) prepare(db *Database) (err error) { return } -func getAccountData(s *accountDataStatements, ctx context.Context, pk string, docId string) (*AccountDataCosmosData, error) { - response := AccountDataCosmosData{} +func getAccountData(s *accountDataStatements, ctx context.Context, pk string, docId string) (*accountDataCosmosData, error) { + response := accountDataCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -83,34 +91,12 @@ func getAccountData(s *accountDataStatements, ctx context.Context, pk string, do return &response, err } -func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []AccountDataCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func (s *accountDataStatements) insertAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { // INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) id := "" if roomID == "" { id = fmt.Sprintf("%s_%s", localpart, dataType) @@ -119,24 +105,23 @@ func (s *accountDataStatements) insertAccountData( } docId := id - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - dbData, _ := getAccountData(s, ctx, pk, cosmosDocId) + dbData, _ := getAccountData(s, ctx, s.getPartitionKey(), cosmosDocId) if dbData != nil { // ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 dbData.SetUpdateTime() dbData.AccountData.Content = content } else { - var result = AccountDataCosmos{ + var result = accountDataCosmos{ LocalPart: localpart, RoomId: roomID, Type: dataType, Content: content, } - dbData = &AccountDataCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData = &accountDataCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), AccountData: result, } } @@ -157,13 +142,16 @@ func (s *accountDataStatements) selectAccountData( error, ) { // "SELECT room_id, type, content FROM account_data WHERE localpart = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, } - - response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params) + var rows []accountDataCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAccountDataStmt, params, &rows) if err != nil { return nil, nil, err @@ -172,8 +160,8 @@ func (s *accountDataStatements) selectAccountData( global := map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{} - for i := 0; i < len(response); i++ { - var row = response[i] + for i := 0; i < len(rows); i++ { + var row = rows[i] var roomID = row.AccountData.RoomId if roomID != "" { if _, ok := rooms[row.AccountData.RoomId]; !ok { @@ -194,25 +182,28 @@ func (s *accountDataStatements) selectAccountDataByType( var bytes []byte // "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, "@x3": roomID, "@x4": dataType, } - - response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params) + var rows []accountDataCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAccountDataByTypeStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { + if len(rows) == 0 { return data, nil } - bytes = response[0].AccountData.Content + bytes = rows[0].AccountData.Content data = json.RawMessage(bytes) return diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index 7ff922417..b5727ee12 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -46,7 +46,7 @@ import ( // ); // ` -type AccountCosmos struct { +type accountCosmos struct { UserID string `json:"user_id"` Localpart string `json:"local_part"` ServerName gomatrixserverlib.ServerName `json:"server_name"` @@ -56,12 +56,12 @@ type AccountCosmos struct { Created int64 `json:"created_ts"` } -type AccountCosmosData struct { +type accountCosmosData struct { cosmosdbapi.CosmosDocument - Account AccountCosmos `json:"mx_userapi_account"` + Account accountCosmos `json:"mx_userapi_account"` } -type AccountCosmosUserCount struct { +type accountCosmosUserCount struct { UserCount int64 `json:"usercount"` } @@ -74,6 +74,14 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } +func (s *accountsStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *accountsStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db s.selectPasswordHashStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2 and c.mx_userapi_account.is_deactivated = false" @@ -84,29 +92,8 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv return } -func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []AccountCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) { - response := AccountCosmosData{} +func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*accountCosmosData, error) { + response := accountCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -122,8 +109,8 @@ func getAccount(s *accountsStatements, ctx context.Context, pk string, docId str return &response, err } -func setAccount(s *accountsStatements, ctx context.Context, account AccountCosmosData) (*AccountCosmosData, error) { - response := AccountCosmosData{} +func setAccount(s *accountsStatements, ctx context.Context, account accountCosmosData) (*accountCosmosData, error) { + response := accountCosmosData{} var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(account.Pk, account.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -135,7 +122,7 @@ func setAccount(s *accountsStatements, ctx context.Context, account AccountCosmo return &response, ex } -func mapFromAccount(db AccountCosmos) api.Account { +func mapFromAccount(db accountCosmos) api.Account { return api.Account{ AppServiceID: db.AppServiceID, Localpart: db.Localpart, @@ -144,8 +131,8 @@ func mapFromAccount(db AccountCosmos) api.Account { } } -func mapToAccount(api api.Account) AccountCosmos { - return AccountCosmos{ +func mapToAccount(api api.Account) accountCosmos { + return accountCosmos{ AppServiceID: api.AppServiceID, Localpart: api.Localpart, ServerName: api.ServerName, @@ -175,14 +162,11 @@ func (s *accountsStatements) insertAccount( data.PasswordHash = hash data.IsDeactivated = false - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - docId := result.Localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var dbData = AccountCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = accountCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Account: data, } @@ -206,12 +190,10 @@ func (s *accountsStatements) updatePassword( ) (err error) { // "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) docId := localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getAccount(s, ctx, pk, cosmosDocId) + var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } @@ -230,13 +212,10 @@ func (s *accountsStatements) deactivateAccount( ) (err error) { // "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - docId := localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getAccount(s, ctx, pk, cosmosDocId) + var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } @@ -255,27 +234,30 @@ func (s *accountsStatements) selectPasswordHash( ) (hash string, err error) { // "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, } - - response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params) + var rows []accountCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectPasswordHashStmt, params, &rows) if err != nil { return "", err } - if len(response) == 0 { + if len(rows) == 0 { return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart)) } - if len(response) != 1 { + if len(rows) != 1 { return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart)) } - return response[0].Account.PasswordHash, nil + return rows[0].Account.PasswordHash, nil } func (s *accountsStatements) selectAccountByLocalpart( @@ -284,23 +266,26 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account // "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, } - - response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params) + var rows []accountCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectAccountByLocalpartStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { + if len(rows) == 0 { return nil, nil } - acc = mapFromAccount(response[0].Account) + acc = mapFromAccount(rows[0].Account) acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName @@ -312,25 +297,20 @@ func (s *accountsStatements) selectNewNumericLocalpart( ) (id int64, err error) { // "SELECT COUNT(localpart) FROM account_accounts" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []AccountCosmosUserCount params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - var options = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params) - var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + + var rows []accountCosmosUserCount + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &response, - options) + s.getPartitionKey(), s.selectNewNumericLocalpartStmt, params, &rows) - if ex != nil { - return -1, ex + if err != nil { + return -1, err } - return int64(response[0].UserCount), nil + return int64(rows[0].UserCount), nil } diff --git a/userapi/storage/accounts/cosmosdb/key_backup_table.go b/userapi/storage/accounts/cosmosdb/key_backup_table.go index a6d968df1..bd525e8e7 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_table.go @@ -40,12 +40,7 @@ import ( // CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); // ` -type KeyBackupCosmosData struct { - cosmosdbapi.CosmosDocument - KeyBackup KeyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"` -} - -type KeyBackupCosmos struct { +type keyBackupCosmos struct { UserId string `json:"user_id"` RoomId string `json:"room_id"` SessionId string `json:"session_id"` @@ -56,7 +51,12 @@ type KeyBackupCosmos struct { SessionData []byte `json:"session_data"` } -type KeyBackupCosmosNumber struct { +type keyBackupCosmosData struct { + cosmosdbapi.CosmosDocument + KeyBackup keyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"` +} + +type keyBackupCosmosNumber struct { Number int64 `json:"number"` } @@ -110,50 +110,16 @@ type keyBackupStatements struct { serverName gomatrixserverlib.ServerName } -func queryKeyBackup(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []KeyBackupCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *keyBackupStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func queryKeyBackupNumber(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosNumber, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []KeyBackupCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil +func (s *keyBackupStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*KeyBackupCosmosData, error) { - response := KeyBackupCosmosData{} +func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) { + response := keyBackupCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -169,7 +135,7 @@ func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId return &response, err } -func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup KeyBackupCosmosData) (*KeyBackupCosmosData, error) { +func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup keyBackupCosmosData) (*keyBackupCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -200,13 +166,17 @@ func (s keyBackupStatements) countKeys( // "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" // err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": version, } - rows, err := queryKeyBackupNumber(&s, ctx, s.countKeysStmt, params) + var rows []keyBackupCosmosNumber + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.countKeysStmt, params, &rows) if err != nil { return -1, err @@ -228,13 +198,11 @@ func (s *keyBackupStatements) insertBackupKey( // _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( // ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), // ) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - data := KeyBackupCosmos{ + data := keyBackupCosmos{ UserId: userID, RoomId: key.RoomID, SessionId: key.SessionID, @@ -245,8 +213,8 @@ func (s *keyBackupStatements) insertBackupKey( SessionData: key.SessionData, } - dbData := &KeyBackupCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &keyBackupCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), KeyBackup: data, } @@ -270,13 +238,11 @@ func (s *keyBackupStatements) updateBackupKey( // ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, // ) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackup(s, ctx, pk, cosmosDocId) + res, err := getKeyBackup(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return @@ -302,13 +268,17 @@ func (s *keyBackupStatements) selectKeys( ) (map[string]map[string]api.KeyBackupSession, error) { // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + // "WHERE user_id = $1 AND version = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": version, } - rows, err := queryKeyBackup(s, ctx, s.selectKeysStmt, params) + var rows []keyBackupCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeysStmt, params, &rows) if err != nil { return nil, err @@ -327,14 +297,18 @@ func (s *keyBackupStatements) selectKeysByRoomID( ) (map[string]map[string]api.KeyBackupSession, error) { // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + // "WHERE user_id = $1 AND version = $2 AND room_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": version, "@x4": roomID, } - rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDStmt, params) + var rows []keyBackupCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeysByRoomIDStmt, params, &rows) if err != nil { return nil, err @@ -355,15 +329,19 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( ) (map[string]map[string]api.KeyBackupSession, error) { // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + // "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": userID, "@x3": version, "@x4": roomID, "@x5": sessionID, } - rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDAndSessionIDStmt, params) + var rows []keyBackupCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows) if err != nil { return nil, err @@ -379,7 +357,7 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( return unpackKeys(ctx, &rows) } -func unpackKeys(ctx context.Context, rows *[]KeyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) { +func unpackKeys(ctx context.Context, rows *[]keyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) { result := make(map[string]map[string]api.KeyBackupSession) for _, item := range *rows { var key api.InternalKeyBackupSession diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go index d08358fcf..ecd2c56ef 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go @@ -41,12 +41,7 @@ import ( // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); // ` -type KeyBackupVersionCosmosData struct { - cosmosdbapi.CosmosDocument - KeyBackupVersion KeyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"` -} - -type KeyBackupVersionCosmos struct { +type keyBackupVersionCosmos struct { UserId string `json:"user_id"` Version int64 `json:"vesion"` Algorithm string `json:"algorithm"` @@ -55,7 +50,12 @@ type KeyBackupVersionCosmos struct { Deleted int `json:"deleted"` } -type KeyBackupVersionCosmosNumber struct { +type keyBackupVersionCosmosData struct { + cosmosdbapi.CosmosDocument + KeyBackupVersion keyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"` +} + +type keyBackupVersionCosmosNumber struct { Number int64 `json:"number"` } @@ -91,33 +91,16 @@ type keyBackupVersionStatements struct { serverName gomatrixserverlib.ServerName } -func queryKeyBackupVersionNumber(s *keyBackupVersionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupVersionCosmosNumber, error) { - var response []KeyBackupVersionCosmosNumber - - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(qry, params) - var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - //WHen there is no data these GroupBy queries return errors - // if err != nil { - // return nil, err - // } - - if len(response) == 0 { - return nil, cosmosdbutil.ErrNoRows - } - - return response, nil +func (s *keyBackupVersionStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) } -func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*KeyBackupVersionCosmosData, error) { - response := KeyBackupVersionCosmosData{} +func (s *keyBackupVersionStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) { + response := keyBackupVersionCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -133,7 +116,7 @@ func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk return &response, err } -func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup KeyBackupVersionCosmosData) (*KeyBackupVersionCosmosData, error) { +func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup keyBackupVersionCosmosData) (*keyBackupVersionCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -171,14 +154,11 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return "", seqErr } // err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); docId := fmt.Sprintf("%s_%d", userID, versionInt) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - - data := KeyBackupVersionCosmos{ + data := keyBackupVersionCosmos{ UserId: userID, Version: versionInt, Algorithm: algorithm, @@ -187,8 +167,8 @@ func (s *keyBackupVersionStatements) insertKeyBackup( Deleted: 0, } - dbData := &KeyBackupVersionCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + dbData := &keyBackupVersionCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), KeyBackupVersion: data, } @@ -211,13 +191,11 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( if err != nil { return fmt.Errorf("invalid version") } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); docId := fmt.Sprintf("%s_%d", userID, versionInt) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err @@ -243,13 +221,11 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( if err != nil { return fmt.Errorf("invalid version") } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); docId := fmt.Sprintf("%s_%d", userID, versionInt) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err @@ -275,13 +251,11 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( if err != nil { return false, fmt.Errorf("invalid version") } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); docId := fmt.Sprintf("%s_%d", userID, versionInt) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return false, err @@ -309,17 +283,21 @@ func (s *keyBackupVersionStatements) selectKeyBackup( var versionInt int64 if version == "" { // var v *int64 // allows nulls - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) params := map[string]interface{}{ "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, + "@x2": s.getCollectionName(), "@x3": userID, } // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - response, err1 := queryKeyBackupVersionNumber(s, ctx, s.selectLatestVersionStmt, params) + var rows []keyBackupVersionCosmosNumber + err = cosmosdbapi.PerformQueryAllPartitions(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.selectLatestVersionStmt, params, &rows) - if err1 != nil { + if err != nil { if err == cosmosdbutil.ErrNoRows { err = nil } @@ -327,12 +305,12 @@ func (s *keyBackupVersionStatements) selectKeyBackup( // if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { // return // } - if response == nil || len(response) == 0 { + if rows == nil || len(rows) == 0 { err = cosmosdbutil.ErrNoRows versionInt = 0 return } - versionInt = response[0].Number + versionInt = rows[0].Number } else { if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { return @@ -342,13 +320,11 @@ func (s *keyBackupVersionStatements) selectKeyBackup( if err != nil { return } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); docId := fmt.Sprintf("%s_%d", userID, versionInt) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return diff --git a/userapi/storage/accounts/cosmosdb/openid_table.go b/userapi/storage/accounts/cosmosdb/openid_table.go index 44a0a46d4..b21c5f25d 100644 --- a/userapi/storage/accounts/cosmosdb/openid_table.go +++ b/userapi/storage/accounts/cosmosdb/openid_table.go @@ -22,7 +22,7 @@ import ( // ` // OpenIDToken represents an OpenID token -type OpenIDTokenCosmos struct { +type openIDTokenCosmos struct { Token string `json:"token"` UserID string `json:"user_id"` ExpiresAtMS int64 `json:"expires_at"` @@ -30,7 +30,7 @@ type OpenIDTokenCosmos struct { type OpenIdTokenCosmosData struct { cosmosdbapi.CosmosDocument - OpenIdToken OpenIDTokenCosmos `json:"mx_userapi_openidtoken"` + OpenIdToken openIDTokenCosmos `json:"mx_userapi_openidtoken"` } type tokenStatements struct { @@ -41,7 +41,15 @@ type tokenStatements struct { serverName gomatrixserverlib.ServerName } -func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken { +func (s *tokenStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *tokenStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func mapFromToken(db openIDTokenCosmos) api.OpenIDToken { return api.OpenIDToken{ ExpiresAtMS: db.ExpiresAtMS, Token: db.Token, @@ -49,35 +57,14 @@ func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken { } } -func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos { - return OpenIDTokenCosmos{ +func mapToToken(api api.OpenIDToken) openIDTokenCosmos { + return openIDTokenCosmos{ ExpiresAtMS: api.ExpiresAtMS, Token: api.Token, UserID: api.UserID, } } -func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []OpenIdTokenCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2" @@ -95,20 +82,17 @@ func (s *tokenStatements) insertToken( ) (err error) { // "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + docId := token + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var result = &api.OpenIDToken{ UserID: localpart, Token: token, ExpiresAtMS: expiresAtMS, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) - - docId := result.Token - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var dbData = OpenIdTokenCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), OpenIdToken: mapToToken(*result), } @@ -136,22 +120,26 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes( var openIDTokenAttrs api.OpenIDTokenAttributes // "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": token, } - response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params) + var rows []OpenIdTokenCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectTokenStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { + if len(rows) == 0 { return nil, nil } - var openIdToken = response[0].OpenIdToken + var openIdToken = rows[0].OpenIdToken openIDTokenAttrs = api.OpenIDTokenAttributes{ UserID: openIdToken.UserID, ExpiresAtMS: openIdToken.ExpiresAtMS, diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index aa5bc0d74..cc5daad87 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -38,15 +38,15 @@ import ( // ` // Profile represents the profile for a Matrix account. -type ProfileCosmos struct { +type profileCosmos struct { Localpart string `json:"local_part"` DisplayName string `json:"display_name"` AvatarURL string `json:"avatar_url"` } -type ProfileCosmosData struct { +type profileCosmosData struct { cosmosdbapi.CosmosDocument - Profile ProfileCosmos `json:"mx_userapi_profile"` + Profile profileCosmos `json:"mx_userapi_profile"` } type profilesStatements struct { @@ -59,7 +59,15 @@ type profilesStatements struct { tableName string } -func mapFromProfile(db ProfileCosmos) authtypes.Profile { +func (s *profilesStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *profilesStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func mapFromProfile(db profileCosmos) authtypes.Profile { return authtypes.Profile{ AvatarURL: db.AvatarURL, DisplayName: db.DisplayName, @@ -67,8 +75,8 @@ func mapFromProfile(db ProfileCosmos) authtypes.Profile { } } -func mapToProfile(api authtypes.Profile) ProfileCosmos { - return ProfileCosmos{ +func mapToProfile(api authtypes.Profile) profileCosmos { + return profileCosmos{ AvatarURL: api.AvatarURL, DisplayName: api.DisplayName, Localpart: api.Localpart, @@ -83,29 +91,8 @@ func (s *profilesStatements) prepare(db *Database) (err error) { return } -func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ProfileCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) { - response := ProfileCosmosData{} +func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*profileCosmosData, error) { + response := profileCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -121,7 +108,7 @@ func getProfile(s *profilesStatements, ctx context.Context, pk string, docId str return &response, err } -func setProfile(s *profilesStatements, ctx context.Context, profile ProfileCosmosData) (*ProfileCosmosData, error) { +func setProfile(s *profilesStatements, ctx context.Context, profile profileCosmosData) (*profileCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(profile.Pk, profile.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -142,14 +129,11 @@ func (s *profilesStatements) insertProfile( Localpart: localpart, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) - docId := localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var dbData = ProfileCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = profileCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Profile: mapToProfile(*result), } @@ -169,27 +153,30 @@ func (s *profilesStatements) selectProfileByLocalpart( ) (*authtypes.Profile, error) { // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, } - - response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params) + var rows []profileCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectProfileByLocalpartStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { - return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response))) + if len(rows) == 0 { + return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(rows))) } - if len(response) != 1 { - return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response))) + if len(rows) != 1 { + return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(rows))) } - var result = mapFromProfile(response[0].Profile) + var result = mapFromProfile(rows[0].Profile) return &result, nil } @@ -198,12 +185,10 @@ func (s *profilesStatements) setAvatarURL( ) (err error) { // "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) docId := localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getProfile(s, ctx, pk, cosmosDocId) + var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } @@ -222,11 +207,9 @@ func (s *profilesStatements) setDisplayName( ) (err error) { // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) docId := localpart - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response, exGet = getProfile(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } @@ -246,21 +229,24 @@ func (s *profilesStatements) selectProfilesBySearch( var profiles []authtypes.Profile // "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": searchString, "@x3": limit, } - - response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params) + var rows []profileCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectProfilesBySearchStmt, params, &rows) if err != nil { return nil, err } - for i := 0; i < len(response); i++ { - var responseData = response[i] + for i := 0; i < len(rows); i++ { + var responseData = rows[i] profiles = append(profiles, mapFromProfile(responseData.Profile)) } diff --git a/userapi/storage/accounts/cosmosdb/threepid_table.go b/userapi/storage/accounts/cosmosdb/threepid_table.go index 52d30cd26..2eb18c3b6 100644 --- a/userapi/storage/accounts/cosmosdb/threepid_table.go +++ b/userapi/storage/accounts/cosmosdb/threepid_table.go @@ -36,15 +36,15 @@ import ( // PRIMARY KEY(threepid, medium) // ); -type ThreePIDCosmos struct { +type threePIDCosmos struct { Localpart string `json:"local_part"` ThreePID string `json:"three_pid"` Medium string `json:"medium"` } -type ThreePIDCosmosData struct { +type threePIDCosmosData struct { cosmosdbapi.CosmosDocument - ThreePID ThreePIDCosmos `json:"mx_userapi_threepid"` + ThreePID threePIDCosmos `json:"mx_userapi_threepid"` } type threepidStatements struct { @@ -56,6 +56,14 @@ type threepidStatements struct { tableName string } +func (s *threepidStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *threepidStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + func (s *threepidStatements) prepare(db *Database) (err error) { s.db = db s.selectLocalpartForThreePIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.three_pid = @x2 and c.mx_userapi_threepid.medium = @x3" @@ -64,50 +72,33 @@ func (s *threepidStatements) prepare(db *Database) (err error) { return } -func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []ThreePIDCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - func (s *threepidStatements) selectLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { // "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": threepid, "@x3": medium, } - response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params) + var rows []threePIDCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectLocalpartForThreePIDStmt, params, &rows) if err != nil { return "", err } - if len(response) == 0 { + if len(rows) == 0 { return "", nil } - return response[0].ThreePID.Localpart, nil + return rows[0].ThreePID.Localpart, nil } func (s *threepidStatements) selectThreePIDsForLocalpart( @@ -115,23 +106,27 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( ) (threepids []authtypes.ThreePID, err error) { // "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, } - response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params) + var rows []threePIDCosmosData + err = cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectThreePIDsForLocalpartStmt, params, &rows) if err != nil { return threepids, err } - if len(response) == 0 { + if len(rows) == 0 { return threepids, nil } threepids = []authtypes.ThreePID{} - for _, item := range response { + for _, item := range rows { threepids = append(threepids, authtypes.ThreePID{ Address: item.ThreePID.ThreePID, Medium: item.ThreePID.Medium, @@ -145,19 +140,16 @@ func (s *threepidStatements) insertThreePID( ) (err error) { // "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" - var result = ThreePIDCosmos{ + var result = threePIDCosmos{ Localpart: localpart, Medium: medium, ThreePID: threepid, } - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) - docId := fmt.Sprintf("%s_%s", threepid, medium) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var dbData = ThreePIDCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var dbData = threePIDCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), ThreePID: result, } @@ -179,11 +171,9 @@ func (s *threepidStatements) deleteThreePID( ctx context.Context, threepid string, medium string) (err error) { // "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName) docId := fmt.Sprintf("%s_%s", threepid, medium) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey()) _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, s.db.cosmosConfig.DatabaseName, diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index bfd8f8847..3c31d8e50 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -49,7 +49,7 @@ import ( // ); // ` -type DeviceCosmos struct { +type deviceCosmos struct { ID string `json:"device_id"` UserID string `json:"user_id"` // The access_token granted to this device. @@ -69,12 +69,12 @@ type DeviceCosmos struct { AppserviceID string `json:"app_service_id"` } -type DeviceCosmosData struct { +type deviceCosmosData struct { cosmosdbapi.CosmosDocument - Device DeviceCosmos `json:"mx_userapi_device"` + Device deviceCosmos `json:"mx_userapi_device"` } -type DeviceCosmosSessionCount struct { +type deviceCosmosSessionCount struct { SessionCount int64 `json:"sessioncount"` } @@ -90,7 +90,15 @@ type devicesStatements struct { tableName string } -func mapFromDevice(db DeviceCosmos) api.Device { +func (s *devicesStatements) getCollectionName() string { + return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) +} + +func (s *devicesStatements) getPartitionKey() string { + return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) +} + +func mapFromDevice(db deviceCosmos) api.Device { return api.Device{ AccessToken: db.AccessToken, AppserviceID: db.AppserviceID, @@ -103,9 +111,9 @@ func mapFromDevice(db DeviceCosmos) api.Device { } } -func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos { +func mapTodevice(api api.Device, s *devicesStatements) deviceCosmos { localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName) - return DeviceCosmos{ + return deviceCosmos{ AccessToken: api.AccessToken, AppserviceID: api.AppserviceID, ID: api.ID, @@ -118,29 +126,8 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos { } } -func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []DeviceCosmosData - - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(qry, params) - _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) - - if err != nil { - return nil, err - } - return response, nil -} - -func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) { - response := DeviceCosmosData{} +func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*deviceCosmosData, error) { + response := deviceCosmosData{} err := cosmosdbapi.GetDocumentOrNil( s.db.connection, s.db.cosmosConfig, @@ -156,7 +143,7 @@ func getDevice(s *devicesStatements, ctx context.Context, pk string, docId strin return &response, err } -func setDevice(s *devicesStatements, ctx context.Context, device DeviceCosmosData) (*DeviceCosmosData, error) { +func setDevice(s *devicesStatements, ctx context.Context, device deviceCosmosData) (*deviceCosmosData, error) { var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, @@ -191,31 +178,31 @@ func (s *devicesStatements) insertDevice( var sessionID int64 // "SELECT COUNT(access_token) FROM device_devices" // HACK: Do we need a Cosmos Table for the sequence? - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response []DeviceCosmosSessionCount params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), } - var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) - var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, + var rows []deviceCosmosSessionCount + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, - query, - &response, - optionsQry) + s.getPartitionKey(), s.selectDevicesCountStmt, params, &rows) if err != nil { return nil, err } - sessionID = response[0].SessionCount + sessionID = rows[0].SessionCount sessionID++ // "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + // " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" - data := DeviceCosmos{ + // access_token TEXT PRIMARY KEY, + // UNIQUE (localpart, device_id) + // HACK: check for duplicate PK as we are using the UNIQUE key for the DocId + docId := fmt.Sprintf("%s_%s", localpart, id) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + + data := deviceCosmos{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, @@ -226,14 +213,8 @@ func (s *devicesStatements) insertDevice( UserAgent: userAgent, } - // access_token TEXT PRIMARY KEY, - // UNIQUE (localpart, device_id) - // HACK: check for duplicate PK as we are using the UNIQUE key for the DocId - docId := fmt.Sprintf("%s_%s", localpart, id) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - - var dbData = DeviceCosmosData{ - CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId), + var dbData = deviceCosmosData{ + CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId), Device: data, } @@ -257,11 +238,9 @@ func (s *devicesStatements) deleteDevice( ctx context.Context, id, localpart string, ) error { // "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) docId := fmt.Sprintf("%s_%s", localpart, id) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var options = cosmosdbapi.GetDeleteDocumentOptions(pk) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey()) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( ctx, s.db.cosmosConfig.DatabaseName, @@ -279,20 +258,23 @@ func (s *devicesStatements) deleteDevices( ctx context.Context, localpart string, devices []string, ) error { // "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var response []DeviceCosmosData params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, "@x3": devices, } - response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params) + var rows []deviceCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectDevicesByLocalpartStmt, params, &rows) if err != nil { return err } - for _, item := range response { + for _, item := range rows { s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart) } return err @@ -302,22 +284,25 @@ func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, localpart, exceptDeviceID string, ) error { // "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) exceptDevices := []string{ exceptDeviceID, } params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, "@x3": exceptDevices, } - - response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params) + var rows []deviceCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectDevicesByLocalpartStmt, params, &rows) if err != nil { return err } - for _, item := range response { + for _, item := range rows { s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart) } return err @@ -327,11 +312,9 @@ func (s *devicesStatements) updateDeviceName( ctx context.Context, localpart, deviceID string, displayName *string, ) error { // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } @@ -349,24 +332,27 @@ func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { // "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var response []DeviceCosmosData params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": accessToken, } - response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params) + var rows []deviceCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectDeviceByTokenStmt, params, &rows) if err != nil { return nil, err } - if len(response) == 0 { + if len(rows) == 0 { return nil, cosmosdbutil.ErrNoRows } if err == nil { - result := mapFromDevice(response[0].Device) + result := mapFromDevice(rows[0].Device) return &result, nil } return nil, err @@ -378,11 +364,9 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { // "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return nil, exGet } @@ -395,20 +379,23 @@ func (s *devicesStatements) selectDevicesByLocalpart( ) ([]api.Device, error) { devices := []api.Device{} // "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": localpart, "@x3": exceptDeviceID, } - - response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params) + var rows []deviceCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectDevicesByLocalpartExceptIDStmt, params, &rows) if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { dev := mapFromDevice(item.Device) dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) @@ -420,19 +407,21 @@ func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { // "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" var devices []api.Device - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) - var response []DeviceCosmosData params := map[string]interface{}{ - "@x1": dbCollectionName, + "@x1": s.getCollectionName(), "@x2": deviceIDs, } - - response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params) + var rows []deviceCosmosData + err := cosmosdbapi.PerformQuery(ctx, + s.db.connection, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + s.getPartitionKey(), s.selectDevicesByIDStmt, params, &rows) if err != nil { return nil, err } - for _, item := range response { + for _, item := range rows { dev := mapFromDevice(item.Device) devices = append(devices, dev) } @@ -443,11 +432,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, lastSeenTs := time.Now().UnixNano() / 1000000 // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName) docId := fmt.Sprintf("%s_%s", localpart, deviceID) - cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) - pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) - var response, exGet = getDevice(s, ctx, pk, cosmosDocId) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) + var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet }