Use a common way to generate CollectionName and PartitionKey (#18)

* - Create CosmosDocument as a base class
- Add CT and UT
- Refactor all tables to use the CosmosDocument

* - Add UpsertDocument method to perform updates in a generic way
- Add SetUpdateTime() to update the UT for updates
- Refactor it all

* - Add Performquery method
- Refactor appservice_events_table

* - Update naffka Topics and Messages to use the common pattern

* - Update keyserver to use the common pattern for collection and PK

* - Update mediaapi to use the common pattern for collection and pk

* - Update roomserver to use the common pattern for collectionname and pk

* - Update signingkeyserver to use the common pattern for collectionname and pk

* - Update userapi touse the common pattern for collectionname and pk

* - Update partitionOffset to use the common collectionname and pk
- Remove generic GetPartitionKey() method

Co-authored-by: alexf@example.com <alexf@example.com>
This commit is contained in:
alexfca 2021-09-23 09:02:37 +10:00 committed by GitHub
parent acf63daf79
commit 927238a686
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
54 changed files with 2299 additions and 2904 deletions

View file

@ -45,20 +45,20 @@ import (
// CREATE INDEX IF NOT EXISTS appservice_events_as_id ON appservice_events(as_id);
// `
type EventCosmos struct {
type eventCosmos struct {
ID int64 `json:"id"`
AppServiceID string `json:"as_id"`
HeaderedEventJSON []byte `json:"headered_event_json"`
TXNID int64 `json:"txn_id"`
}
type EventNumberCosmosData struct {
type eventNumberCosmosData struct {
Number int `json:"number"`
}
type EventCosmosData struct {
type eventCosmosData struct {
cosmosdbapi.CosmosDocument
Event EventCosmos `json:"mx_appservice_event"`
Event eventCosmos `json:"mx_appservice_event"`
}
// "SELECT id, headered_event_json, txn_id " +
@ -119,50 +119,16 @@ func (s *eventsStatements) prepare(db *Database, writer sqlutil.Writer) (err err
return
}
func queryEvent(s *eventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *eventsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryEventEventNumber(s *eventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventNumberCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventNumberCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *eventsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) (*EventCosmosData, error) {
response := EventCosmosData{}
func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) (*eventCosmosData, error) {
response := eventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -178,7 +144,7 @@ func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string)
return &response, err
}
func setEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) (*EventCosmosData, error) {
func setEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -190,7 +156,7 @@ func setEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) (
return &event, ex
}
func deleteEvent(s *eventsStatements, ctx context.Context, event EventCosmosData) error {
func deleteEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(event.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -223,13 +189,17 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
// "SELECT id, headered_event_json, txn_id " +
// "FROM appservice_events WHERE as_id = $1 ORDER BY txn_id DESC, id ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": applicationServiceID,
}
eventRows, err := queryEvent(s, ctx, s.selectEventsByApplicationServiceIDStmt, params)
var rows []eventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventsByApplicationServiceIDStmt, params, &rows)
if err != nil {
log.WithFields(log.Fields{
@ -237,7 +207,7 @@ func (s *eventsStatements) selectEventsByApplicationServiceID(
}).WithError(err).Fatalf("appservice unable to select new events to send")
}
events, maxID, txnID, eventsRemaining, err = retrieveEvents(eventRows, limit)
events, maxID, txnID, eventsRemaining, err = retrieveEvents(rows, limit)
if err != nil {
return
}
@ -252,7 +222,7 @@ func checkNamedErr(fn func() error, err *error) {
}
}
func retrieveEvents(eventRows []EventCosmosData, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
func retrieveEvents(eventRows []eventCosmosData, limit int) (events []gomatrixserverlib.HeaderedEvent, maxID, txnID int, eventsRemaining bool, err error) {
// Get current time for use in calculating event age
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
@ -318,18 +288,22 @@ func (s *eventsStatements) countEventsByApplicationServiceID(
// "SELECT COUNT(id) FROM appservice_events WHERE as_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": appServiceID,
}
response, err := queryEventEventNumber(s, ctx, s.countEventsByApplicationServiceIDStmt, params)
var rows []eventNumberCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventsByApplicationServiceIDStmt, params, &rows)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
count = response[0].Number
count = rows[0].Number
return count, nil
}
@ -350,12 +324,10 @@ func (s *eventsStatements) insertEvent(
// "INSERT INTO appservice_events(as_id, headered_event_json, txn_id) " +
// "VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%s", appServiceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, err := getEvent(s, ctx, pk, cosmosDocId)
dbData, err := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.Event.HeaderedEventJSON = eventJSON
@ -369,15 +341,15 @@ func (s *eventsStatements) insertEvent(
// appServiceID,
// eventJSON,
// -1, // No transaction ID yet
data := EventCosmos{
data := eventCosmos{
AppServiceID: appServiceID,
HeaderedEventJSON: eventJSON,
ID: idSeq,
TXNID: -1,
}
dbData = &EventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &eventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Event: data,
}
@ -401,19 +373,23 @@ func (s *eventsStatements) updateTxnIDForEvents(
) (err error) {
// "UPDATE appservice_events SET txn_id = $1 WHERE as_id = $2 AND id <= $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": appserviceID,
"@x3": maxID,
}
response, err := queryEvent(s, ctx, s.updateTxnIDForEventsStmt, params)
var rows []eventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.updateTxnIDForEventsStmt, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
item.Event.TXNID = int64(txnID)
// _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
_, err = setEvent(s, ctx, item)
@ -430,19 +406,24 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
) (err error) {
// "DELETE FROM appservice_events WHERE as_id = $1 AND id <= $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": appserviceID,
"@x3": eventTableID,
}
response, err := queryEvent(s, ctx, s.deleteEventsBeforeAndIncludingIDStmt, params)
var rows []eventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteEventsBeforeAndIncludingIDStmt, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
// _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
err = deleteEvent(s, ctx, item)
}

View file

@ -32,13 +32,13 @@ import (
// );
// `
type BlacklistCosmos struct {
type blacklistCosmos struct {
ServerName string `json:"server_name"`
}
type BlacklistCosmosData struct {
type blacklistCosmosData struct {
cosmosdbapi.CosmosDocument
Blacklist BlacklistCosmos `json:"mx_federationsender_blacklist"`
Blacklist blacklistCosmos `json:"mx_federationsender_blacklist"`
}
// const insertBlacklistSQL = "" +
@ -64,8 +64,16 @@ type blacklistStatements struct {
tableName string
}
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) {
response := BlacklistCosmosData{}
func (s *blacklistStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *blacklistStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*blacklistCosmosData, error) {
response := blacklistCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -81,28 +89,7 @@ func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId
return &response, err
}
func queryBlacklist(s *blacklistStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BlacklistCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []BlacklistCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error {
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData blacklistCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -137,22 +124,20 @@ func (s *blacklistStatements) InsertBlacklist(
// stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getBlacklist(s, ctx, pk, cosmosDocId)
dbData, _ := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
} else {
data := BlacklistCosmos{
data := blacklistCosmos{
ServerName: string(serverName),
}
dbData = &BlacklistCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &blacklistCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Blacklist: data,
}
}
@ -177,13 +162,11 @@ func (s *blacklistStatements) SelectBlacklist(
// stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// res, err := stmt.QueryContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
res, err := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return false, err
}
@ -201,13 +184,11 @@ func (s *blacklistStatements) DeleteBlacklist(
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
// stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err := stmt.ExecContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
res, err := getBlacklist(s, ctx, s.getPartitionKey(), cosmosDocId)
if res != nil {
_ = deleteBlacklist(s, ctx, *res)
}
@ -220,13 +201,17 @@ func (s *blacklistStatements) DeleteAllBlacklist(
// "DELETE FROM federationsender_blacklist"
// stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
// rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryBlacklist(s, ctx, s.deleteAllBlacklistStmt, params)
var rows []blacklistCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteAllBlacklistStmt, params, &rows)
if err != nil {
return err

View file

@ -38,7 +38,7 @@ import (
// );
// `
type InboundPeekCosmos struct {
type inboundPeekCosmos struct {
RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
PeekID string `json:"peek_id"`
@ -47,9 +47,9 @@ type InboundPeekCosmos struct {
RenewalInterval int64 `json:"renewal_interval"`
}
type InboundPeekCosmosData struct {
type inboundPeekCosmosData struct {
cosmosdbapi.CosmosDocument
InboundPeek InboundPeekCosmos `json:"mx_federationsender_inbound_peek"`
InboundPeek inboundPeekCosmos `json:"mx_federationsender_inbound_peek"`
}
// const insertInboundPeekSQL = "" +
@ -88,29 +88,16 @@ type inboundPeeksStatements struct {
tableName string
}
func queryInboundPeek(s *inboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []InboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *inboundPeeksStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*InboundPeekCosmosData, error) {
response := InboundPeekCosmosData{}
func (s *inboundPeeksStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*inboundPeekCosmosData, error) {
response := inboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -126,7 +113,7 @@ func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, d
return &response, err
}
func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek InboundPeekCosmosData) (*InboundPeekCosmosData, error) {
func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek inboundPeekCosmosData) (*inboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -138,7 +125,7 @@ func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek
return &inboundPeek, ex
}
func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData InboundPeekCosmosData) error {
func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData inboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -172,19 +159,17 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
// stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getInboundPeek(s, ctx, pk, cosmosDocId)
dbData, _ := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.InboundPeek.RenewedTimestamp = nowMilli
dbData.InboundPeek.RenewalInterval = renewalInterval
} else {
data := InboundPeekCosmos{
data := inboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
@ -193,8 +178,8 @@ func (s *inboundPeeksStatements) InsertInboundPeek(
RenewalInterval: renewalInterval,
}
dbData = &InboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &inboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
InboundPeek: data,
}
}
@ -218,14 +203,12 @@ func (s *inboundPeeksStatements) RenewInboundPeek(
// "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getInboundPeek(s, ctx, pk, cosmosDocId)
res, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return
@ -248,14 +231,12 @@ func (s *inboundPeeksStatements) SelectInboundPeek(
) (*types.InboundPeek, error) {
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getPartitionKey(), docId)
// row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getInboundPeek(s, ctx, pk, cosmosDocId)
row, err := getInboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if row == nil {
return nil, nil
@ -278,14 +259,18 @@ func (s *inboundPeeksStatements) SelectInboundPeeks(
) (inboundPeeks []types.InboundPeek, err error) {
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.selectInboundPeeksStmt, params)
var rows []inboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInboundPeeksStmt, params, &rows)
if err != nil {
return
@ -310,15 +295,19 @@ func (s *inboundPeeksStatements) DeleteInboundPeek(
) (err error) {
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
var rows []inboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows)
if err != nil {
return
@ -339,14 +328,18 @@ func (s *inboundPeeksStatements) DeleteInboundPeeks(
) (err error) {
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
var rows []inboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteInboundPeekStmt, params, &rows)
if err != nil {
return

View file

@ -45,15 +45,15 @@ import (
// ON federationsender_joined_hosts (room_id)
// `
type JoinedHostCosmos struct {
type joinedHostCosmos struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
ServerName string `json:"server_name"`
}
type JoinedHostCosmosData struct {
type joinedHostCosmosData struct {
cosmosdbapi.CosmosDocument
JoinedHost JoinedHostCosmos `json:"mx_federationsender_joined_host"`
JoinedHost joinedHostCosmos `json:"mx_federationsender_joined_host"`
}
// const insertJoinedHostsSQL = "" +
@ -96,8 +96,16 @@ type joinedHostsStatements struct {
tableName string
}
func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, docId string) (*JoinedHostCosmosData, error) {
response := JoinedHostCosmosData{}
func (s *joinedHostsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *joinedHostsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, docId string) (*joinedHostCosmosData, error) {
response := joinedHostCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -113,49 +121,7 @@ func getJoinedHost(s *joinedHostsStatements, ctx context.Context, pk string, doc
return &response, err
}
func queryJoinedHostDistinct(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []JoinedHostCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryJoinedHost(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []JoinedHostCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData JoinedHostCosmosData) error {
func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData joinedHostCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -194,23 +160,21 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
// stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
// ON federationsender_joined_hosts (event_id);
docId := fmt.Sprintf("%s", eventID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getJoinedHost(s, ctx, pk, cosmosDocId)
dbData, _ := getJoinedHost(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData == nil {
data := JoinedHostCosmos{
data := joinedHostCosmos{
EventID: eventID,
RoomID: roomID,
ServerName: string(serverName),
}
dbData = &JoinedHostCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &joinedHostCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
JoinedHost: data,
}
// _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
@ -231,14 +195,17 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
for _, eventID := range eventIDs {
// "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
var rows []joinedHostCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteJoinedHostsStmt, params, &rows)
for _, item := range rows {
if err = deleteJoinedHost(s, ctx, item); err != nil {
@ -254,14 +221,18 @@ func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
) error {
// "DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
var rows []joinedHostCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteJoinedHostsStmt, params, &rows)
// _, err := stmt.ExecContext(ctx, roomID)
for _, item := range rows {
@ -278,14 +249,18 @@ func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
// "SELECT event_id, server_name FROM federationsender_joined_hosts" +
// " WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
var rows []joinedHostCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectJoinedHostsStmt, params, &rows)
if err != nil {
return nil, err
@ -305,13 +280,18 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
) ([]gomatrixserverlib.ServerName, error) {
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
// rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
var rows []joinedHostCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAllJoinedHostsStmt, params, &rows)
if err != nil {
return nil, err
}
@ -337,14 +317,19 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
// sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomIDs,
}
// rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
var rows []joinedHostCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAllJoinedHostsStmt, params, &rows)
if err != nil {
return nil, err
}
@ -359,7 +344,7 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
return result, nil
}
func rowsToJoinedHosts(rows *[]JoinedHostCosmosData) []types.JoinedHost {
func rowsToJoinedHosts(rows *[]joinedHostCosmosData) []types.JoinedHost {
var result []types.JoinedHost
if rows == nil {
return result

View file

@ -38,7 +38,7 @@ import (
// );
// `
type OutboundPeekCosmos struct {
type outboundPeekCosmos struct {
RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
PeekID string `json:"peek_id"`
@ -47,9 +47,9 @@ type OutboundPeekCosmos struct {
RenewalInterval int64 `json:"renewal_interval"`
}
type OutboundPeekCosmosData struct {
type outboundPeekCosmosData struct {
cosmosdbapi.CosmosDocument
OutboundPeek OutboundPeekCosmos `json:"mx_federationsender_outbound_peek"`
OutboundPeek outboundPeekCosmos `json:"mx_federationsender_outbound_peek"`
}
// const insertOutboundPeekSQL = "" +
@ -85,29 +85,16 @@ type outboundPeeksStatements struct {
tableName string
}
func queryOutboundPeek(s *outboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []OutboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *outboundPeeksStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*OutboundPeekCosmosData, error) {
response := OutboundPeekCosmosData{}
func (s *outboundPeeksStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*outboundPeekCosmosData, error) {
response := outboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -123,7 +110,7 @@ func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string,
return &response, err
}
func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek OutboundPeekCosmosData) (*OutboundPeekCosmosData, error) {
func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek outboundPeekCosmosData) (*outboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -135,7 +122,7 @@ func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPe
return &outboundPeek, ex
}
func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData OutboundPeekCosmosData) error {
func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData outboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -168,20 +155,18 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
// stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getOutboundPeek(s, ctx, pk, cosmosDocId)
dbData, _ := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.OutboundPeek.RenewalInterval = renewalInterval
dbData.OutboundPeek.RenewedTimestamp = nowMilli
} else {
data := OutboundPeekCosmos{
data := outboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
@ -190,8 +175,8 @@ func (s *outboundPeeksStatements) InsertOutboundPeek(
RenewalInterval: renewalInterval,
}
dbData = &OutboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &outboundPeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
OutboundPeek: data,
}
@ -215,14 +200,12 @@ func (s *outboundPeeksStatements) RenewOutboundPeek(
// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
res, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return
@ -245,14 +228,12 @@ func (s *outboundPeeksStatements) SelectOutboundPeek(
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
row, err := getOutboundPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return nil, err
@ -280,14 +261,19 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks(
if err != nil {
return
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryOutboundPeek(s, ctx, s.selectOutboundPeeksStmt, params)
var rows []outboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectOutboundPeeksStmt, params, &rows)
if err != nil {
return
@ -313,15 +299,19 @@ func (s *outboundPeeksStatements) DeleteOutboundPeek(
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeekStmt, params)
var rows []outboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteOutboundPeekStmt, params, &rows)
if err != nil {
return
@ -343,14 +333,18 @@ func (s *outboundPeeksStatements) DeleteOutboundPeeks(
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeeksStmt, params)
var rows []outboundPeekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteOutboundPeeksStmt, params, &rows)
if err != nil {
return

View file

@ -38,19 +38,19 @@ import (
// ON federationsender_queue_edus (json_nid, server_name);
// `
type QueueEDUCosmos struct {
type queueEDUCosmos struct {
EDUType string `json:"edu_type"`
ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
}
type QueueEDUCosmosNumber struct {
type queueEDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueueEDUCosmosData struct {
type queueEDUCosmosData struct {
cosmosdbapi.CosmosDocument
QueueEDU QueueEDUCosmos `json:"mx_federationsender_queue_edu"`
QueueEDU queueEDUCosmos `json:"mx_federationsender_queue_edu"`
}
// const insertQueueEDUSQL = "" +
@ -96,70 +96,15 @@ type queueEDUsStatements struct {
tableName string
}
func queryQueueEDUC(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueueEDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *queueEDUsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryQueueEDUCDistinct(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueueEDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *queueEDUsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func queryQueueEDUCNumber(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueueEDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData QueueEDUCosmosData) error {
func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData queueEDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -198,21 +143,19 @@ func (s *queueEDUsStatements) InsertQueueEDU(
// stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
// ON federationsender_queue_edus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, eduType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := QueueEDUCosmos{
data := queueEDUCosmos{
EDUType: eduType,
JSONNID: nid,
ServerName: string(serverName),
}
dbData := &QueueEDUCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &queueEDUCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
QueueEDU: data,
}
@ -244,16 +187,20 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
// deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
"@x3": jsonNIDs,
}
// stmt := sqlutil.TxStmt(txn, deleteStmt)
// _, err = stmt.ExecContext(ctx, params...)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
var rows []queueEDUCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), deleteQueueEDUsSQL, params, &rows)
if err != nil {
return err
@ -280,15 +227,20 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
// " LIMIT $2"
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
"@x3": limit,
}
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
var rows []queueEDUCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), deleteQueueEDUsSQL, params, &rows)
if err != nil {
return nil, err
}
@ -309,14 +261,19 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUReferenceJSONCountStmt, params)
var rows []queueEDUCosmosNumber
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueEDUReferenceJSONCountStmt, params, &rows)
if len(rows) == 0 {
return -1, nil
}
@ -333,14 +290,19 @@ func (s *queueEDUsStatements) SelectQueueEDUCount(
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUCountStmt, params)
var rows []queueEDUCosmosNumber
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueEDUCountStmt, params, &rows)
if len(rows) == 0 {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
@ -358,14 +320,19 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames(
// "SELECT DISTINCT server_name FROM federationsender_queue_edus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueueEDUCDistinct(s, ctx, s.selectQueueEDUServerNamesStmt, params)
var rows []queueEDUCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueEDUServerNamesStmt, params, &rows)
if err != nil {
return nil, err
}

View file

@ -35,14 +35,14 @@ import (
// );
// `
type QueueJSONCosmos struct {
type queueJSONCosmos struct {
JSONNID int64 `json:"json_nid"`
JSONBody []byte `json:"json_body"`
}
type QueueJSONCosmosData struct {
type queueJSONCosmosData struct {
cosmosdbapi.CosmosDocument
QueueJSON QueueJSONCosmos `json:"mx_federationsender_queue_json"`
QueueJSON queueJSONCosmos `json:"mx_federationsender_queue_json"`
}
// const insertJSONSQL = "" +
@ -68,28 +68,15 @@ type queueJSONStatements struct {
tableName string
}
func queryQueueJSON(s *queueJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueJSONCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueueJSONCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *queueJSONStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData QueueJSONCosmosData) error {
func (s *queueJSONStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData queueJSONCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -122,22 +109,20 @@ func (s *queueJSONStatements) InsertQueueJSON(
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
idSeq, err := GetNextQueueJSONNID(s, ctx)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", idSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
//Convert to byte
jsonData := []byte(json)
data := QueueJSONCosmos{
data := queueJSONCosmos{
JSONNID: idSeq,
JSONBody: jsonData,
}
dbData := &QueueJSONCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &queueJSONCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
QueueJSON: data,
}
@ -165,16 +150,20 @@ func (s *queueJSONStatements) DeleteQueueJSON(
// "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": nids,
}
// deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
// stmt := sqlutil.TxStmt(txn, deleteStmt)
rows, err := queryQueueJSON(s, ctx, deleteJSONSQL, params)
var rows []queueJSONCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), deleteJSONSQL, params, &rows)
if err != nil {
return err
@ -198,15 +187,19 @@ func (s *queueJSONStatements) SelectQueueJSON(
// "SELECT json_nid, json_body FROM federationsender_queue_json" +
// " WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": jsonNIDs,
}
// selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
// selectStmt, err := txn.Prepare(selectSQL)
rows, err := queryQueueJSON(s, ctx, selectJSONSQL, params)
var rows []queueJSONCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectJSONSQL, params, &rows)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)

View file

@ -39,19 +39,19 @@ import (
// ON federationsender_queue_pdus (json_nid, server_name);
// `
type QueuePDUCosmos struct {
type queuePDUCosmos struct {
TransactionID string `json:"transaction_id"`
ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
}
type QueuePDUCosmosNumber struct {
type queuePDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueuePDUCosmosData struct {
type queuePDUCosmosData struct {
cosmosdbapi.CosmosDocument
QueuePDU QueuePDUCosmos `json:"mx_federationsender_queue_pdu"`
QueuePDU queuePDUCosmos `json:"mx_federationsender_queue_pdu"`
}
// const insertQueuePDUSQL = "" +
@ -108,70 +108,15 @@ type queuePDUsStatements struct {
tableName string
}
func queryQueuePDU(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueuePDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *queuePDUsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryQueuePDUDistinct(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueuePDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *queuePDUsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func queryQueuePDUNumber(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []QueuePDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData QueuePDUCosmosData) error {
func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData queuePDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -210,21 +155,19 @@ func (s *queuePDUsStatements) InsertQueuePDU(
// "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
// ON federationsender_queue_pdus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := QueuePDUCosmos{
data := queuePDUCosmos{
JSONNID: nid,
ServerName: string(serverName),
TransactionID: string(transactionID),
}
dbData := &QueuePDUCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &queuePDUCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
QueuePDU: data,
}
@ -255,16 +198,20 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
// "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
"@x3": jsonNIDs,
}
// deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
rows, err := queryQueuePDU(s, ctx, deleteQueuePDUsSQL, params)
var rows []queuePDUCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), deleteQueuePDUsSQL, params, &rows)
if err != nil {
return err
@ -290,14 +237,18 @@ func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
// " ORDER BY transaction_id ASC" +
// " LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
rows, err := queryQueuePDU(s, ctx, s.selectQueueNextTransactionIDStmt, params)
var rows []queuePDUCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueNextTransactionIDStmt, params, &rows)
if err != nil {
return "", err
@ -319,14 +270,18 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueueReferenceJSONCountStmt, params)
var rows []queuePDUCosmosNumber
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueReferenceJSONCountStmt, params, &rows)
if err != nil {
return -1, err
@ -348,14 +303,18 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueuePDUsCountStmt, params)
var rows []queuePDUCosmosNumber
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueuePDUsCountStmt, params, &rows)
if err != nil {
return 0, err
@ -382,16 +341,20 @@ func (s *queuePDUsStatements) SelectQueuePDUs(
// " WHERE server_name = $1" +
// " LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": serverName,
"@x3": limit,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueuePDU(s, ctx, s.selectQueuePDUsStmt, params)
var rows []queuePDUCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueuePDUsStmt, params, &rows)
if err != nil {
return nil, err
@ -412,14 +375,19 @@ func (s *queuePDUsStatements) SelectQueuePDUServerNames(
// "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueuePDUDistinct(s, ctx, s.selectQueueServerNamesStmt, params)
var rows []queuePDUCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectQueueServerNamesStmt, params, &rows)
if err != nil {
return nil, err
}

View file

@ -47,6 +47,51 @@ func (doc *CosmosDocument) SetUpdateTime() {
doc.Ut = now.Format(time.RFC3339)
}
func PerformQuery(ctx context.Context,
conn CosmosConnection,
databaseName string,
containerName string,
partitonKey string,
qryString string,
params map[string]interface{},
response interface{}) error {
optionsQry := GetQueryDocumentsOptions(partitonKey)
var query = GetQuery(qryString, params)
_, err := GetClient(conn).QueryDocuments(
ctx,
databaseName,
containerName,
query,
&response,
optionsQry)
return err
}
func PerformQueryAllPartitions(ctx context.Context,
conn CosmosConnection,
databaseName string,
containerName string,
qryString string,
params map[string]interface{},
response interface{}) error {
var optionsQry = GetQueryAllPartitionsDocumentsOptions()
var query = GetQuery(qryString, params)
_, err := GetClient(conn).QueryDocuments(
ctx,
databaseName,
containerName,
query,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
}
return nil
}
func GenerateDocument(
collection string,
tenantName string,

View file

@ -32,6 +32,10 @@ func GetDocumentId(tenantName string, collectionName string, id string) string {
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, safeId)
}
func GetPartitionKey(tenantName string, collectionName string) string {
func GetPartitionKeyByCollection(tenantName string, collectionName string) string {
return fmt.Sprintf("%s,%s", collectionName, tenantName)
}
func GetPartitionKeyByUniqueId(tenantName string, collectionName string, uniqueId string) string {
return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, uniqueId)
}

View file

@ -46,15 +46,15 @@ import (
// );
// `
type PartitionOffsetCosmos struct {
type partitionOffsetCosmos struct {
Topic string `json:"topic"`
Partition int32 `json:"partition"`
PartitionOffset int64 `json:"partition_offset"`
}
type PartitionOffsetCosmosData struct {
type partitionOffsetCosmosData struct {
cosmosdbapi.CosmosDocument
PartitionOffset PartitionOffsetCosmos `json:"mx_partition_offset"`
PartitionOffset partitionOffsetCosmos `json:"mx_partition_offset"`
}
// "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1"
@ -84,8 +84,18 @@ type PartitionOffsetStatements struct {
tableName string
}
func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*PartitionOffsetCosmosData, error) {
response := PartitionOffsetCosmosData{}
func (s PartitionOffsetStatements) getCollectionName() string {
// Include the Prefix
tableName := fmt.Sprintf("%s_%s", s.prefix, s.tableName)
return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName)
}
func (s *PartitionOffsetStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.CosmosConfig.TenantName, s.getCollectionName())
}
func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk string, docId string) (*partitionOffsetCosmosData, error) {
response := partitionOffsetCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.Connection,
s.db.CosmosConfig,
@ -101,27 +111,6 @@ func getPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, pk st
return &response, err
}
func queryPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PartitionOffsetCosmosData, error) {
var dbCollectionName = getCollectionName(*s)
var pk = cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName)
var response []PartitionOffsetCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.Connection).QueryDocuments(
ctx,
s.db.CosmosConfig.DatabaseName,
s.db.CosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
// Prepare converts the raw SQL statements into prepared statements.
// Takes a prefix to prepend to the table name used to store the partition offsets.
// This allows multiple components to share the same database schema.
@ -155,13 +144,18 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
// "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1"
var dbCollectionName = getCollectionName(*s)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": topic,
}
rows, err := queryPartitionOffset(s, ctx, s.selectPartitionOffsetsStmt, params)
var rows []partitionOffsetCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.Connection,
s.db.CosmosConfig.DatabaseName,
s.db.CosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPartitionOffsetsStmt, params, &rows)
// rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic)
if err != nil {
return nil, err
@ -197,25 +191,23 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
// stmt := TxStmt(txn, s.upsertPartitionOffsetStmt)
dbCollectionName := getCollectionName(*s)
// UNIQUE (topic, partition)
docId := fmt.Sprintf("%s_%d", topic, partition)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getPartitionOffset(s, ctx, pk, cosmosDocId)
dbData, _ := getPartitionOffset(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.PartitionOffset.PartitionOffset = offset
} else {
data := PartitionOffsetCosmos{
data := partitionOffsetCosmos{
Partition: partition,
PartitionOffset: offset,
Topic: topic,
}
dbData = &PartitionOffsetCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.CosmosConfig.TenantName, pk, cosmosDocId),
dbData = &partitionOffsetCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.CosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
PartitionOffset: data,
}
@ -231,9 +223,3 @@ func (s *PartitionOffsetStatements) upsertPartitionOffset(
&dbData)
})
}
func getCollectionName(s PartitionOffsetStatements) string {
// Include the Prefix
tableName := fmt.Sprintf("%s_%s", s.prefix, s.tableName)
return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName)
}

View file

@ -28,21 +28,21 @@ import (
// );
// `
type TopicCosmos struct {
type topicCosmos struct {
TopicName string `json:"topic_name"`
TopicNID int64 `json:"topic_nid"`
}
type TopicCosmosNumber struct {
type topicCosmosNumber struct {
Number int64 `json:"number"`
}
type TopicCosmosData struct {
type topicCosmosData struct {
cosmosdbapi.CosmosDocument
Topic TopicCosmos `json:"mx_naffka_topic"`
Topic topicCosmos `json:"mx_naffka_topic"`
}
type MessageCosmos struct {
type messageCosmos struct {
TopicNID int64 `json:"topic_nid"`
MessageOffset int64 `json:"message_offset"`
MessageKey []byte `json:"message_key"`
@ -50,9 +50,9 @@ type MessageCosmos struct {
MessageTimestampNS int64 `json:"message_timestamp_ns"`
}
type MessageCosmosData struct {
type messageCosmosData struct {
cosmosdbapi.CosmosDocument
Message MessageCosmos `json:"mx_naffka_message"`
Message messageCosmos `json:"mx_naffka_message"`
}
// const insertTopicSQL = "" +
@ -104,71 +104,24 @@ type topicsStatements struct {
tableNameMessages string
}
func queryTopic(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]TopicCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics)
var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName)
var response []TopicCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments(
ctx,
s.DB.cosmosConfig.DatabaseName,
s.DB.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *topicsStatements) getCollectionNameTopics() string {
return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics)
}
func queryTopicNumber(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]TopicCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameTopics)
var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName)
var response []TopicCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments(
ctx,
s.DB.cosmosConfig.DatabaseName,
s.DB.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *topicsStatements) getPartitionKeyTopics() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameTopics())
}
func queryMessage(s *topicsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MessageCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages)
var pk = cosmosdbapi.GetPartitionKey(s.DB.cosmosConfig.ContainerName, dbCollectionName)
var response []MessageCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.DB.connection).QueryDocuments(
ctx,
s.DB.cosmosConfig.DatabaseName,
s.DB.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *topicsStatements) getCollectionNameMessages() string {
return cosmosdbapi.GetCollectionName(s.DB.databaseName, s.tableNameMessages)
}
func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*TopicCosmosData, error) {
response := TopicCosmosData{}
func (s *topicsStatements) getPartitionKeyMessages() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.DB.cosmosConfig.TenantName, s.getCollectionNameMessages())
}
func getTopic(s *topicsStatements, ctx context.Context, pk string, docId string) (*topicCosmosData, error) {
response := topicCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.DB.connection,
s.DB.cosmosConfig,
@ -212,24 +165,22 @@ func (t *topicsStatements) InsertTopic(
// return errSeq
// }
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics)
// topic_name TEXT UNIQUE,
docId := fmt.Sprintf("%s", topicName)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameTopics(), docId)
dbData, _ := getTopic(t, ctx, pk, cosmosDocId)
dbData, _ := getTopic(t, ctx, t.getPartitionKeyTopics(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.Topic.TopicName = topicName
} else {
data := TopicCosmos{
data := topicCosmos{
TopicNID: topicNID,
TopicName: topicName,
}
dbData = &TopicCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, t.DB.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &topicCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameTopics(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyTopics(), cosmosDocId),
Topic: data,
}
}
@ -250,14 +201,18 @@ func (t *topicsStatements) SelectNextTopicNID(
// "SELECT COUNT(topic_nid)+1 AS topic_nid FROM naffka_topics"
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": t.getCollectionNameTopics(),
}
// stmt := sqlutil.TxStmt(txn, t.selectNextTopicNIDStmt)
// err = stmt.QueryRowContext(ctx).Scan(&topicNID)
rows, err := queryTopicNumber(t, ctx, t.selectNextTopicNIDStmt, params)
var rows []topicCosmosNumber
err = cosmosdbapi.PerformQuery(ctx,
t.DB.connection,
t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyTopics(), t.selectNextTopicNIDStmt, params, &rows)
if err != nil {
return 0, err
@ -279,14 +234,12 @@ func (t *topicsStatements) SelectTopic(
// stmt := sqlutil.TxStmt(txn, t.selectTopicStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics)
// topic_name TEXT UNIQUE,
docId := fmt.Sprintf("%s", topicName)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameTopics(), docId)
// err = stmt.QueryRowContext(ctx, topicName).Scan(&topicNID)
res, err := getTopic(t, ctx, pk, cosmosDocId)
res, err := getTopic(t, ctx, t.getPartitionKeyTopics(), cosmosDocId)
if err != nil {
return 0, err
@ -304,14 +257,18 @@ func (t *topicsStatements) SelectTopics(
// "SELECT topic_name, topic_nid FROM naffka_topics"
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameTopics)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": t.getCollectionNameTopics(),
}
// stmt := sqlutil.TxStmt(txn, t.selectTopicsStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryTopic(t, ctx, t.selectTopicsStmt, params)
var rows []topicCosmosData
err := cosmosdbapi.PerformQuery(ctx,
t.DB.connection,
t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyTopics(), t.selectTopicsStmt, params, &rows)
if err != nil {
return nil, err
@ -340,13 +297,11 @@ func (t *topicsStatements) InsertTopics(
// stmt := sqlutil.TxStmt(txn, t.insertTopicsStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages)
// UNIQUE (topic_nid, message_offset)
docId := fmt.Sprintf("%d_%d", topicNID, messageOffset)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(t.DB.cosmosConfig.ContainerName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(t.DB.cosmosConfig.ContainerName, t.getCollectionNameMessages(), docId)
data := MessageCosmos{
data := messageCosmos{
TopicNID: topicNID,
MessageOffset: messageOffset,
MessageKey: topicKey,
@ -354,8 +309,8 @@ func (t *topicsStatements) InsertTopics(
MessageTimestampNS: messageTimestampNs,
}
dbData := &MessageCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, t.DB.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &messageCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(t.getCollectionNameMessages(), t.DB.cosmosConfig.TenantName, t.getPartitionKeyMessages(), cosmosDocId),
Message: data,
}
@ -379,9 +334,8 @@ func (t *topicsStatements) SelectMessages(
// " FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" +
// " ORDER BY message_offset ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": t.getCollectionNameMessages(),
"@x2": topicNID,
"@x3": startOffset,
"@x4": endOffset,
@ -389,7 +343,12 @@ func (t *topicsStatements) SelectMessages(
// stmt := sqlutil.TxStmt(txn, t.selectMessagesStmt)
// rows, err := stmt.QueryContext(ctx, topicNID, startOffset, endOffset)
rows, err := queryMessage(t, ctx, t.selectMessagesStmt, params)
var rows []messageCosmosData
err := cosmosdbapi.PerformQuery(ctx,
t.DB.connection,
t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyMessages(), t.selectMessagesStmt, params, &rows)
if err != nil {
return nil, err
@ -416,15 +375,19 @@ func (t *topicsStatements) SelectMaxOffset(
// "SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" +
// " ORDER BY message_offset DESC LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(t.DB.databaseName, t.tableNameMessages)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": t.getCollectionNameMessages(),
"@x2": topicNID,
}
// stmt := sqlutil.TxStmt(txn, t.selectMaxOffsetStmt)
// err = stmt.QueryRowContext(ctx, topicNID).Scan(&offset)
rows, err := queryMessage(t, ctx, t.selectMaxOffsetStmt, params)
var rows []messageCosmosData
err = cosmosdbapi.PerformQuery(ctx,
t.DB.connection,
t.DB.cosmosConfig.DatabaseName,
t.DB.cosmosConfig.ContainerName,
t.getPartitionKeyMessages(), t.selectMaxOffsetStmt, params, &rows)
if err != nil {
return 0, err

View file

@ -34,15 +34,15 @@ import (
// );
// `
type CrossSigningKeysCosmos struct {
type crossSigningKeysCosmos struct {
UserID string `json:"user_id"`
KeyType int64 `json:"key_type"`
KeyData []byte `json:"key_data"`
}
type CrossSigningKeysCosmosData struct {
type crossSigningKeysCosmosData struct {
cosmosdbapi.CosmosDocument
CrossSigningKeys CrossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"`
CrossSigningKeys crossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"`
}
// "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
@ -62,8 +62,16 @@ type crossSigningKeysStatements struct {
tableName string
}
func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*CrossSigningKeysCosmosData, error) {
response := CrossSigningKeysCosmosData{}
func (s *crossSigningKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *crossSigningKeysStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk string, docId string) (*crossSigningKeysCosmosData, error) {
response := crossSigningKeysCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -79,27 +87,6 @@ func getCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, pk
return &response, err
}
func queryCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningKeysCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []CrossSigningKeysCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func NewSqliteCrossSigningKeysTable(db *Database) (tables.CrossSigningKeys, error) {
s := &crossSigningKeysStatements{
db: db,
@ -115,12 +102,18 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
) (r types.CrossSigningKeyMap, err error) {
// "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
// " WHERE user_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
}
rows, err := queryCrossSigningKeys(s, ctx, s.selectCrossSigningKeysForUserStmt, params)
var rows []crossSigningKeysCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningKeysForUserStmt, params, &rows)
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil {
return nil, err
@ -154,25 +147,23 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// PRIMARY KEY (user_id, key_type)
docId := fmt.Sprintf("%s_%s", userID, keyType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getCrossSigningKeys(s, ctx, pk, cosmosDocId)
dbData, _ := getCrossSigningKeys(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.CrossSigningKeys.KeyData = keyData
} else {
data := CrossSigningKeysCosmos{
data := crossSigningKeysCosmos{
UserID: userID,
KeyType: int64(keyTypeInt),
KeyData: keyData,
}
dbData = &CrossSigningKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &crossSigningKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
CrossSigningKeys: data,
}
}

View file

@ -36,7 +36,7 @@ import (
// );
// `
type CrossSigningSigsCosmos struct {
type crossSigningSigsCosmos struct {
OriginUserId string `json:"origin_user_id"`
OriginKeyId string `json:"origin_key_id"`
TargetUserId string `json:"target_user_id"`
@ -44,9 +44,9 @@ type CrossSigningSigsCosmos struct {
Signature []byte `json:"signature"`
}
type CrossSigningSigsCosmosData struct {
type crossSigningSigsCosmosData struct {
cosmosdbapi.CosmosDocument
CrossSigningSigs CrossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"`
CrossSigningSigs crossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"`
}
// "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
@ -74,8 +74,16 @@ type crossSigningSigsStatements struct {
tableName string
}
func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*CrossSigningSigsCosmosData, error) {
response := CrossSigningSigsCosmosData{}
func (s *crossSigningSigsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *crossSigningSigsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk string, docId string) (*crossSigningSigsCosmosData, error) {
response := crossSigningSigsCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -91,28 +99,7 @@ func getCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, pk
return &response, err
}
func queryCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningSigsCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []CrossSigningSigsCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData CrossSigningSigsCosmosData) error {
func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData crossSigningSigsCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -147,13 +134,19 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
) (r types.CrossSigningSigMap, err error) {
// "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
// " WHERE target_user_id = $1 AND target_key_id = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": targetUserID,
"@x3": targetKeyID,
}
rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params)
var rows []crossSigningSigsCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows)
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
if err != nil {
return nil, err
@ -187,19 +180,18 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
) error {
// "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
// " VALUES($1, $2, $3, $4, $5)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getCrossSigningSigs(s, ctx, pk, cosmosDocId)
dbData, _ := getCrossSigningSigs(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.CrossSigningSigs.OriginKeyId = string(originKeyID)
dbData.CrossSigningSigs.Signature = signature
} else {
data := CrossSigningSigsCosmos{
data := crossSigningSigsCosmos{
TargetUserId: targetUserID,
TargetKeyId: string(targetKeyID),
OriginUserId: originUserID,
@ -207,8 +199,8 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
Signature: signature,
}
dbData = &CrossSigningSigsCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &crossSigningSigsCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
CrossSigningSigs: data,
}
}
@ -228,13 +220,18 @@ func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
// "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": targetUserID,
"@x3": targetKeyID,
}
rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params)
var rows []crossSigningSigsCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectCrossSigningSigsForTargetStmt, params, &rows)
// if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
// return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
// }

View file

@ -39,7 +39,7 @@ import (
// );
// `
type DeviceKeyCosmos struct {
type deviceKeyCosmos struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
// Use the CosmosDB.Timestamp for this one
@ -49,13 +49,13 @@ type DeviceKeyCosmos struct {
DisplayName string `json:"display_name"`
}
type DeviceKeyCosmosNumber struct {
type deviceKeyCosmosNumber struct {
Number int64 `json:"number"`
}
type DeviceKeyCosmosData struct {
type deviceKeyCosmosData struct {
cosmosdbapi.CosmosDocument
DeviceKey DeviceKeyCosmos `json:"mx_keyserver_device_key"`
DeviceKey deviceKeyCosmos `json:"mx_keyserver_device_key"`
}
// const upsertDeviceKeysSQL = "" +
@ -97,54 +97,8 @@ const deleteDeviceKeysSQL = "" +
// const deleteAllDeviceKeysSQL = "" +
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []DeviceKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryDeviceKeyNumber(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosNumber, error) {
var response []DeviceKeyCosmosNumber
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
//WHen there is no data these GroupBy queries return errors
// if err != nil {
// return nil, err
// }
if len(response) == 0 {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
}
func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*DeviceKeyCosmosData, error) {
response := DeviceKeyCosmosData{}
func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId string) (*deviceKeyCosmosData, error) {
response := deviceKeyCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -160,7 +114,7 @@ func getDeviceKey(s *deviceKeysStatements, ctx context.Context, pk string, docId
return &response, err
}
func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData deviceKeyCosmosData) error {
// "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id)" +
@ -189,8 +143,8 @@ func insertDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De
return nil
}
func mapFromDeviceKeyMessage(key api.DeviceMessage) DeviceKeyCosmos {
return DeviceKeyCosmos{
func mapFromDeviceKeyMessage(key api.DeviceMessage) deviceKeyCosmos {
return deviceKeyCosmos{
DeviceID: key.DeviceID,
DisplayName: key.DisplayName,
KeyJSON: key.KeyJSON,
@ -210,6 +164,14 @@ type deviceKeysStatements struct {
tableName string
}
func (s *deviceKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *deviceKeysStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
@ -221,7 +183,7 @@ func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
return s, nil
}
func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData DeviceKeyCosmosData) error {
func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData deviceKeyCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -239,19 +201,24 @@ func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
// "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
// _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
}
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
var rows []deviceKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
errItem := deleteDeviceKeyCore(s, ctx, item)
if errItem != nil {
return errItem
@ -265,18 +232,23 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
// _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
}
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
var rows []deviceKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectAllDeviceKeysSQL, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
errItem := deleteDeviceKeyCore(s, ctx, item)
if errItem != nil {
return errItem
@ -293,12 +265,18 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
}
response, err := queryDeviceKey(s, ctx, s.selectBatchDeviceKeysStmt, params)
var rows []deviceKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectBatchDeviceKeysStmt, params, &rows)
// rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
@ -306,7 +284,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
// defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
for _, item := range response {
for _, item := range rows {
dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{},
@ -344,13 +322,12 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
// "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
// err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (user_id, device_id)
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getDeviceKey(s, ctx, pk, cosmosDocId)
response, err := getDeviceKey(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil && err != cosmosdbutil.ErrNoRows {
return err
@ -377,15 +354,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
// "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": userID,
}
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
response, err := queryDeviceKeyNumber(s, ctx, selectMaxStreamForUserSQL, params)
var rows []deviceKeyCosmosNumber
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectMaxStreamForUserSQL, params, &rows)
if err != nil {
if err == cosmosdbutil.ErrNoRows {
@ -395,8 +376,8 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
}
}
if len(response) > 0 {
nullStream.Int32 = int32(response[0].Number)
if len(rows) > 0 {
nullStream.Int32 = int32(rows[0].Number)
}
if nullStream.Valid {
@ -415,10 +396,9 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
iStreamIDs[i+1] = streamIDs[i]
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": userID,
"@x4": iStreamIDs,
}
@ -428,7 +408,12 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
// var count sql.NullInt32
// err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
response, err := queryDeviceKeyNumber(s, ctx, countStreamIDsForUserSQL, params)
var rows []deviceKeyCosmosNumber
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), countStreamIDsForUserSQL, params, &rows)
if err != nil {
return 0, err
@ -436,8 +421,8 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
// if count.Valid {
// return int(count.Int32), nil
// }
if response[0].Number >= 0 {
return int(response[0].Number), nil
if rows[0].Number >= 0 {
return int(rows[0].Number), nil
}
return 0, nil
}
@ -448,16 +433,14 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id)" +
// " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
for _, key := range keys {
// UNIQUE (user_id, device_id)
docId := fmt.Sprintf("%s_%s", key.UserID, key.DeviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData := &DeviceKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &deviceKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
DeviceKey: mapFromDeviceKeyMessage(key),
}

View file

@ -35,20 +35,20 @@ import (
// );
// `
type KeyChangeCosmos struct {
type keyChangeCosmos struct {
Partition int32 `json:"partition"`
Offset int64 `json:"_offset"` //offset is reserved
UserID string `json:"user_id"`
}
type KeyChangeUserMaxCosmosData struct {
type keyChangeUserMaxCosmosData struct {
UserID string `json:"user_id"`
MaxOffset int64 `json:"max_offset"`
}
type KeyChangeCosmosData struct {
type keyChangeCosmosData struct {
cosmosdbapi.CosmosDocument
KeyChange KeyChangeCosmos `json:"mx_keyserver_key_change"`
KeyChange keyChangeCosmos `json:"mx_keyserver_key_change"`
}
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
@ -78,8 +78,16 @@ type keyChangesStatements struct {
tableName string
}
func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, docId string) (*KeyChangeCosmosData, error) {
response := KeyChangeCosmosData{}
func (s *keyChangesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *keyChangesStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, docId string) (*keyChangeCosmosData, error) {
response := keyChangeCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -95,27 +103,6 @@ func getKeyChangeUser(s *keyChangesStatements, ctx context.Context, pk string, d
return &response, err
}
func queryKeyChangeUserMax(s *keyChangesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyChangeUserMaxCosmosData, error) {
var response []KeyChangeUserMaxCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
}
return response, nil
}
func NewCosmosDBKeyChangesTable(db *Database) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
@ -132,25 +119,23 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition in
// " ON CONFLICT (partition, offset)" +
// " DO UPDATE SET user_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// UNIQUE (partition, offset)
docId := fmt.Sprintf("%d_%d", partition, offset)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getKeyChangeUser(s, ctx, pk, cosmosDocId)
dbData, _ := getKeyChangeUser(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.KeyChange.UserID = userID
} else {
data := KeyChangeCosmos{
data := keyChangeCosmos{
Offset: offset,
Partition: partition,
UserID: userID,
}
dbData = &KeyChangeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &keyChangeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
KeyChange: data,
}
}
@ -175,22 +160,26 @@ func (s *keyChangesStatements) SelectKeyChanges(
// "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
// rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": partition,
"@x4": fromOffset,
"@x5": toOffset,
}
response, err := queryKeyChangeUserMax(s, ctx, s.selectKeyChangesStmt, params)
var rows []keyChangeUserMaxCosmosData
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectKeyChangesStmt, params, &rows)
if err != nil {
return nil, 0, err
}
for _, item := range response {
for _, item := range rows {
var userID string
var offset int64
userID = item.UserID

View file

@ -42,7 +42,7 @@ import (
// );
// `
type OneTimeKeyCosmos struct {
type oneTimeKeyCosmos struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
KeyID string `json:"key_id"`
@ -52,14 +52,14 @@ type OneTimeKeyCosmos struct {
KeyJSON []byte `json:"key_json"`
}
type OneTimeKeyAlgoNumberCosmosData struct {
type oneTimeKeyAlgoNumberCosmosData struct {
Algorithm string `json:"algorithm"`
Number int `json:"number"`
}
type OneTimeKeyCosmosData struct {
type oneTimeKeyCosmosData struct {
cosmosdbapi.CosmosDocument
OneTimeKey OneTimeKeyCosmos `json:"mx_keyserver_one_time_key"`
OneTimeKey oneTimeKeyCosmos `json:"mx_keyserver_one_time_key"`
}
// const upsertKeysSQL = "" +
@ -102,8 +102,16 @@ type oneTimeKeysStatements struct {
tableName string
}
func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*OneTimeKeyCosmosData, error) {
response := OneTimeKeyCosmosData{}
func (s *oneTimeKeysStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *oneTimeKeysStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, docId string) (*oneTimeKeyCosmosData, error) {
response := oneTimeKeyCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -119,53 +127,7 @@ func getOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, pk string, doc
return &response, err
}
func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) {
var response []OneTimeKeyCosmosData
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoNumberCosmosData, error) {
var response []OneTimeKeyAlgoNumberCosmosData
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
// var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
// When there are no Rows we seem to get the generic Bad Req JSON error
if err != nil {
// return nil, err
}
return response, nil
}
func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData oneTimeKeyCosmosData) error {
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
@ -191,7 +153,7 @@ func insertOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData
return nil
}
func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData OneTimeKeyCosmosData) error {
func deleteOneTimeKeyCore(s *oneTimeKeysStatements, ctx context.Context, dbData oneTimeKeyCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -221,14 +183,19 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
// "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
}
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
var rows []oneTimeKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows)
if err != nil {
return nil, err
}
@ -239,7 +206,7 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
}
result := make(map[string]json.RawMessage)
for _, item := range response {
for _, item := range rows {
var keyID string
var algorithm string
keyID = item.OneTimeKey.KeyID
@ -260,21 +227,25 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de
KeyCount: make(map[string]int),
}
// rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": counts.UserID,
"@x3": counts.DeviceID,
}
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
var rows []oneTimeKeyAlgoNumberCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
var algorithm string
var count int
algorithm = item.Algorithm
@ -293,9 +264,6 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
KeyCount: make(map[string]int),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
// "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
@ -307,9 +275,9 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
// UNIQUE (user_id, device_id, key_id, algorithm)
docId := fmt.Sprintf("%s_%s_%s_%s", keys.UserID, keys.DeviceID, keyID, algo)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := OneTimeKeyCosmos{
data := oneTimeKeyCosmos{
Algorithm: algo,
DeviceID: keys.DeviceID,
KeyID: keyID,
@ -317,8 +285,8 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
UserID: keys.UserID,
}
dbData := &OneTimeKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &oneTimeKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
OneTimeKey: data,
}
@ -330,19 +298,24 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
}
// rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": keys.UserID,
"@x3": keys.DeviceID,
}
// "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
response, err := queryOneTimeKeyAlgoCount(s, ctx, s.selectKeysCountStmt, params)
var rows []oneTimeKeyAlgoNumberCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysCountStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
var algorithm string
var count int
algorithm = item.Algorithm
@ -361,24 +334,29 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
"@x4": algorithm,
}
response, err := queryOneTimeKey(s, ctx, s.selectKeyByAlgorithmStmt, params)
var rows []oneTimeKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeyByAlgorithmStmt, params, &rows)
if err != nil {
if err == cosmosdbutil.ErrNoRows {
return nil, nil
}
return nil, err
}
keyID = response[0].OneTimeKey.KeyID
keyJSONBytes := response[0].OneTimeKey.KeyJSON
err = deleteOneTimeKeyCore(s, ctx, response[0])
keyID = rows[0].OneTimeKey.KeyID
keyJSONBytes := rows[0].OneTimeKey.KeyJSON
err = deleteOneTimeKeyCore(s, ctx, rows[0])
if err != nil {
return nil, err
}

View file

@ -35,16 +35,16 @@ import (
// CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
// `
type StaleDeviceListCosmos struct {
type staleDeviceListCosmos struct {
UserID string `json:"user_id"`
Domain string `json:"domain"`
IsStale bool `json:"is_stale"`
AddedSecs int64 `json:"ts_added_secs"`
}
type StaleDeviceListCosmosData struct {
type staleDeviceListCosmosData struct {
cosmosdbapi.CosmosDocument
StaleDeviceList StaleDeviceListCosmos `json:"mx_keyserver_stale_device_list"`
StaleDeviceList staleDeviceListCosmos `json:"mx_keyserver_stale_device_list"`
}
// const upsertStaleDeviceListSQL = "" +
@ -72,8 +72,16 @@ type staleDeviceListsStatements struct {
tableName string
}
func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk string, docId string) (*StaleDeviceListCosmosData, error) {
response := StaleDeviceListCosmosData{}
func (s *staleDeviceListsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *staleDeviceListsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk string, docId string) (*staleDeviceListCosmosData, error) {
response := staleDeviceListCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -89,26 +97,6 @@ func getStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, pk s
return &response, err
}
func queryStaleDeviceList(s *staleDeviceListsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StaleDeviceListCosmosData, error) {
var response []StaleDeviceListCosmosData
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func NewCosmosDBStaleDeviceListsTable(db *Database) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{
db: db,
@ -131,26 +119,24 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
return err
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// user_id TEXT PRIMARY KEY NOT NULL,
docId := userID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getStaleDeviceList(s, ctx, pk, cosmosDocId)
dbData, _ := getStaleDeviceList(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.StaleDeviceList.IsStale = isStale
dbData.StaleDeviceList.AddedSecs = time.Now().Unix()
} else {
data := StaleDeviceListCosmos{
data := staleDeviceListCosmos{
Domain: string(domain),
IsStale: isStale,
UserID: userID,
}
dbData = &StaleDeviceListCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &staleDeviceListCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
StaleDeviceList: data,
}
}
@ -165,17 +151,22 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
if len(domains) == 0 {
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
// rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": true,
}
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
var rows []staleDeviceListCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectStaleDeviceListsStmt, params, &rows)
if err != nil {
return nil, err
@ -188,12 +179,17 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
// "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
// rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": true,
"@x3": string(domain),
}
rows, err := queryStaleDeviceList(s, ctx, s.selectStaleDeviceListsWithDomainsStmt, params)
var rows []staleDeviceListCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectStaleDeviceListsWithDomainsStmt, params, &rows)
if err != nil {
return nil, err
@ -207,7 +203,7 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
func rowsToUserIDs(ctx context.Context, rows []StaleDeviceListCosmosData) (result []string, err error) {
func rowsToUserIDs(ctx context.Context, rows []staleDeviceListCosmosData) (result []string, err error) {
for _, item := range rows {
var userID string
userID = item.StaleDeviceList.UserID

View file

@ -54,7 +54,7 @@ import (
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin);
// `
type MediaRepositoryCosmos struct {
type mediaRepositoryCosmos struct {
MediaID string `json:"media_id"`
MediaOrigin string `json:"media_origin"`
ContentType string `json:"content_type"`
@ -65,9 +65,9 @@ type MediaRepositoryCosmos struct {
UserID string `json:"user_id"`
}
type MediaRepositoryCosmosData struct {
type mediaRepositoryCosmosData struct {
cosmosdbapi.CosmosDocument
MediaRepository MediaRepositoryCosmos `json:"mx_mediaapi_media_repository"`
MediaRepository mediaRepositoryCosmos `json:"mx_mediaapi_media_repository"`
}
// const insertMediaSQL = `
@ -94,29 +94,16 @@ type mediaStatements struct {
tableName string
}
func queryMediaRepository(s *mediaStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MediaRepositoryCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []MediaRepositoryCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *mediaStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getMediaRepository(s *mediaStatements, ctx context.Context, pk string, docId string) (*MediaRepositoryCosmosData, error) {
response := MediaRepositoryCosmosData{}
func (s *mediaStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getMediaRepository(s *mediaStatements, ctx context.Context, pk string, docId string) (*mediaRepositoryCosmosData, error) {
response := mediaRepositoryCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -150,13 +137,11 @@ func (s *mediaStatements) insertMedia(
// INSERT INTO mediaapi_media_repository (media_id, media_origin, content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id)
// VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin);
docId := fmt.Sprintf("%s_%s", mediaMetadata.MediaID, mediaMetadata.Origin)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := MediaRepositoryCosmos{
data := mediaRepositoryCosmos{
MediaID: string(mediaMetadata.MediaID),
MediaOrigin: string(mediaMetadata.Origin),
ContentType: string(mediaMetadata.ContentType),
@ -167,8 +152,8 @@ func (s *mediaStatements) insertMedia(
UserID: string(mediaMetadata.UserID),
}
dbData := &MediaRepositoryCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &mediaRepositoryCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
MediaRepository: data,
}
@ -207,15 +192,13 @@ func (s *mediaStatements) selectMedia(
// SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_media_repository_index ON mediaapi_media_repository (media_id, media_origin);
docId := fmt.Sprintf("%s_%s", mediaMetadata.MediaID, mediaMetadata.Origin)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// err := s.selectMediaStmt.QueryRowContext(
// ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
row, err := getMediaRepository(s, ctx, pk, cosmosDocId)
row, err := getMediaRepository(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return nil, err
@ -245,9 +228,8 @@ func (s *mediaStatements) selectMediaByHash(
// SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": mediaHash,
"@x3": mediaOrigin,
}
@ -255,7 +237,12 @@ func (s *mediaStatements) selectMediaByHash(
// err := s.selectMediaStmt.QueryRowContext(
// ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
// ).Scan(
rows, err := queryMediaRepository(s, ctx, s.selectMediaByHashStmt, params)
var rows []mediaRepositoryCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectMediaByHashStmt, params, &rows)
if err != nil {
return nil, err

View file

@ -43,7 +43,7 @@ import (
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method);
// `
type ThumbnailCosmos struct {
type thumbnailCosmos struct {
MediaID string `json:"media_id"`
MediaOrigin string `json:"media_origin"`
ContentType string `json:"content_type"`
@ -54,9 +54,9 @@ type ThumbnailCosmos struct {
ResizeMethod string `json:"resize_method"`
}
type ThumbnailCosmosData struct {
type thumbnailCosmosData struct {
cosmosdbapi.CosmosDocument
Thumbnail ThumbnailCosmos `json:"mx_mediaapi_thumbnail"`
Thumbnail thumbnailCosmos `json:"mx_mediaapi_thumbnail"`
}
// const insertThumbnailSQL = `
@ -85,29 +85,16 @@ type thumbnailStatements struct {
tableName string
}
func queryThumbnail(s *thumbnailStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThumbnailCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ThumbnailCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *thumbnailStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*ThumbnailCosmosData, error) {
response := ThumbnailCosmosData{}
func (s *thumbnailStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getThumbnail(s *thumbnailStatements, ctx context.Context, pk string, docId string) (*thumbnailCosmosData, error) {
response := thumbnailCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -142,7 +129,6 @@ func (s *thumbnailStatements) insertThumbnail(
// return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
// stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method);
docId := fmt.Sprintf("%s_%s_%d_%d_%s",
thumbnailMetadata.MediaMetadata.MediaID,
@ -151,8 +137,7 @@ func (s *thumbnailStatements) insertThumbnail(
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// _, err := stmt.ExecContext(
// ctx,
@ -166,7 +151,7 @@ func (s *thumbnailStatements) insertThumbnail(
// thumbnailMetadata.ThumbnailSize.ResizeMethod,
// )
data := ThumbnailCosmos{
data := thumbnailCosmos{
MediaID: string(thumbnailMetadata.MediaMetadata.MediaID),
MediaOrigin: string(thumbnailMetadata.MediaMetadata.Origin),
ContentType: string(thumbnailMetadata.MediaMetadata.ContentType),
@ -177,8 +162,8 @@ func (s *thumbnailStatements) insertThumbnail(
ResizeMethod: string(thumbnailMetadata.ThumbnailSize.ResizeMethod),
}
dbData := &ThumbnailCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &thumbnailCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Thumbnail: data,
}
@ -213,7 +198,6 @@ func (s *thumbnailStatements) selectThumbnail(
// SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 AND width = $3 AND height = $4 AND resize_method = $5
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS mediaapi_thumbnail_index ON mediaapi_thumbnail (media_id, media_origin, width, height, resize_method);
docId := fmt.Sprintf("%s_%s_%d_%d_%s",
mediaID,
@ -222,11 +206,10 @@ func (s *thumbnailStatements) selectThumbnail(
height,
resizeMethod,
)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getThumbnail(s, ctx, pk, cosmosDocId)
row, err := getThumbnail(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return nil, err
@ -253,9 +236,8 @@ func (s *thumbnailStatements) selectThumbnails(
// SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": mediaID,
"@x3": mediaOrigin,
}
@ -263,7 +245,12 @@ func (s *thumbnailStatements) selectThumbnails(
// rows, err := s.selectThumbnailsStmt.QueryContext(
// ctx, mediaID, mediaOrigin,
// )
rows, err := queryThumbnail(s, ctx, s.selectThumbnailsStmt, params)
var rows []thumbnailCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectThumbnailsStmt, params, &rows)
if err != nil {
return nil, err

View file

@ -33,14 +33,14 @@ import (
// );
// `
type EventJSONCosmos struct {
type eventJSONCosmos struct {
EventNID int64 `json:"event_nid"`
EventJSON []byte `json:"event_json"`
}
type EventJSONCosmosData struct {
type eventJSONCosmosData struct {
cosmosdbapi.CosmosDocument
EventJSON EventJSONCosmos `json:"mx_roomserver_event_json"`
EventJSON eventJSONCosmos `json:"mx_roomserver_event_json"`
}
// const insertEventJSONSQL = `
@ -65,8 +65,16 @@ type eventJSONStatements struct {
tableName string
}
func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId string) (*EventJSONCosmosData, error) {
response := EventJSONCosmosData{}
func (s *eventJSONStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *eventJSONStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId string) (*eventJSONCosmosData, error) {
response := eventJSONCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -82,27 +90,6 @@ func getEventJSON(s *eventJSONStatements, ctx context.Context, pk string, docId
return &response, err
}
func queryEventJSON(s *eventJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventJSONCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventJSONCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func NewCosmosDBEventJSONTable(db *Database) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
@ -126,24 +113,21 @@ func (s *eventJSONStatements) InsertEventJSON(
// _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
// INSERT OR REPLACE INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d", eventNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getEventJSON(s, ctx, pk, cosmosDocId)
dbData, _ := getEventJSON(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.EventJSON.EventJSON = eventJSON
} else {
data := EventJSONCosmos{
data := eventJSONCosmos{
EventNID: int64(eventNID),
EventJSON: eventJSON,
}
dbData = &EventJSONCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &eventJSONCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
EventJSON: data,
}
}
@ -165,13 +149,17 @@ func (s *eventJSONStatements) BulkSelectEventJSON(
// WHERE event_nid IN ($1)
// ORDER BY event_nid ASC
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
response, err := queryEventJSON(s, ctx, s.bulkSelectEventJSONStmt, params)
var rows []eventJSONCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventJSONStmt, params, &rows)
if err != nil {
return nil, err
@ -183,7 +171,7 @@ func (s *eventJSONStatements) BulkSelectEventJSON(
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]tables.EventJSONPair, len(eventNIDs))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.EventNID = types.EventNID(item.EventJSON.EventNID)
result.EventJSON = item.EventJSON.EventJSON

View file

@ -37,14 +37,14 @@ import (
// ON CONFLICT DO NOTHING;
// `
type EventStateKeysCosmos struct {
type eventStateKeysCosmos struct {
EventStateKeyNID int64 `json:"event_state_key_nid"`
EventStateKey string `json:"event_state_key"`
}
type EventStateKeysCosmosData struct {
type eventStateKeysCosmosData struct {
cosmosdbapi.CosmosDocument
EventStateKeys EventStateKeysCosmos `json:"mx_roomserver_event_state_keys"`
EventStateKeys eventStateKeysCosmos `json:"mx_roomserver_event_state_keys"`
}
// Same as insertEventTypeNIDSQL
@ -84,29 +84,16 @@ type eventStateKeyStatements struct {
tableName string
}
func queryEventStateKeys(s *eventStateKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventStateKeysCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventStateKeysCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *eventStateKeyStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*EventStateKeysCosmosData, error) {
response := EventStateKeysCosmosData{}
func (s *eventStateKeyStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getEventStateKeys(s *eventStateKeyStatements, ctx context.Context, pk string, docId string) (*eventStateKeysCosmosData, error) {
response := eventStateKeysCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -144,27 +131,25 @@ func ensureEventStateKeys(s *eventStateKeyStatements, ctx context.Context) {
// VALUES (1, '')
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// event_state_key TEXT NOT NULL UNIQUE
docId := ""
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := EventStateKeysCosmos{
data := eventStateKeysCosmos{
EventStateKey: "",
EventStateKeyNID: 1,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData := EventStateKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := eventStateKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
EventStateKeys: data,
}
insertEventStateKeyCore(s, ctx, dbData)
}
func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData EventStateKeysCosmosData) error {
func insertEventStateKeyCore(s *eventStateKeyStatements, ctx context.Context, dbData eventStateKeysCosmosData) error {
err := cosmosdbapi.UpsertDocument(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
@ -189,15 +174,13 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
return 0, cosmosdbutil.ErrNoRows
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// event_state_key TEXT NOT NULL UNIQUE
docId := eventStateKey
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
existing, _ := getEventStateKeys(s, ctx, pk, cosmosDocId)
existing, _ := getEventStateKeys(s, ctx, s.getPartitionKey(), cosmosDocId)
var dbData EventStateKeysCosmosData
var dbData eventStateKeysCosmosData
if existing == nil {
//Not exists, we need to create a new one with a SEQ
eventStateKeyNIDSeq, seqErr := GetNextEventStateKeyNID(s, ctx)
@ -205,14 +188,14 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
return -1, seqErr
}
data := EventStateKeysCosmos{
data := eventStateKeysCosmos{
EventStateKey: eventStateKey,
EventStateKeyNID: eventStateKeyNIDSeq,
}
// event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
dbData = EventStateKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = eventStateKeysCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
EventStateKeys: data,
}
} else {
@ -232,23 +215,27 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
// SELECT event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventStateKey,
}
response, err := queryEventStateKeys(s, ctx, s.selectEventStateKeyNIDStmt, params)
var rows []eventStateKeysCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventStateKeyNIDStmt, params, &rows)
if err != nil {
return 0, err
}
//See storage.assignStateKeyNID()
if len(response) == 0 {
if len(rows) == 0 {
return 0, cosmosdbutil.ErrNoRows
}
return types.EventStateKeyNID(response[0].EventStateKeys.EventStateKeyNID), err
return types.EventStateKeyNID(rows[0].EventStateKeys.EventStateKeyNID), err
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
@ -262,20 +249,24 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventStateKeys,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyNIDStmt, params)
var rows []eventStateKeysCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventStateKeyNIDStmt, params, &rows)
if err != nil {
return nil, err
}
result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
for _, item := range response {
for _, item := range rows {
result[item.EventStateKeys.EventStateKey] = types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)
}
return result, nil
@ -288,19 +279,23 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey(
// SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
// WHERE event_state_key_nid IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventStateKeyNIDs,
}
response, err := queryEventStateKeys(s, ctx, s.bulkSelectEventStateKeyStmt, params)
var rows []eventStateKeysCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventStateKeyStmt, params, &rows)
if err != nil {
return nil, err
}
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
for _, item := range response {
for _, item := range rows {
result[types.EventStateKeyNID(item.EventStateKeys.EventStateKeyNID)] = item.EventStateKeys.EventStateKey
}
return result, nil

View file

@ -42,16 +42,16 @@ import (
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
// `
type EventTypeCosmosData struct {
cosmosdbapi.CosmosDocument
EventType EventTypeCosmos `json:"mx_roomserver_event_type"`
}
type EventTypeCosmos struct {
type eventTypeCosmos struct {
EventTypeNID int64 `json:"event_type_nid"`
EventType string `json:"event_type"`
}
type eventTypeCosmosData struct {
cosmosdbapi.CosmosDocument
EventType eventTypeCosmos `json:"mx_roomserver_event_type"`
}
// Assign a new numeric event type ID.
// The usual case is that the event type is not in the database.
// In that case the ID will be assigned using the next value from the sequence.
@ -96,6 +96,14 @@ type eventTypeStatements struct {
tableName string
}
func (s *eventTypeStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *eventTypeStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) {
s := &eventTypeStatements{
db: db,
@ -112,27 +120,6 @@ func NewCosmosDBEventTypesTable(db *Database) (tables.EventTypes, error) {
return s, nil
}
func queryEventTypes(s *eventTypeStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventTypeCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventTypeCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
@ -142,7 +129,7 @@ func (s *eventTypeStatements) InsertEventTypeNID(
return -1, seqErr
}
data := EventTypeCosmos{
data := eventTypeCosmos{
EventType: eventType,
EventTypeNID: eventTypeNIDSeq,
}
@ -156,17 +143,15 @@ func (s *eventTypeStatements) InsertEventTypeNID(
return types.EventTypeNID(dbData.EventTypeNID), err
}
func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType EventTypeCosmos) (*EventTypeCosmos, error) {
func insertEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType eventTypeCosmos) (*eventTypeCosmos, error) {
// INSERT INTO roomserver_event_types (event_type) VALUES ($1)
// ON CONFLICT DO NOTHING;
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
//Unique on eventType
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, eventType.EventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), eventType.EventType)
var dbData = EventTypeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = eventTypeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
EventType: eventType,
}
@ -200,53 +185,53 @@ func ensureEventTypes(s *eventTypeStatements, ctx context.Context) error {
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
// (1, 'm.room.create'),
_, err := insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"})
_, err := insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 1, EventType: "m.room.create"})
if err != nil {
return err
}
// (2, 'm.room.power_levels'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 2, EventType: "m.room.power_levels"})
if err != nil {
return err
}
// (3, 'm.room.join_rules'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 3, EventType: "m.room.join_rules"})
if err != nil {
return err
}
// (4, 'm.room.third_party_invite'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 4, EventType: "m.room.third_party_invite"})
if err != nil {
return err
}
// (5, 'm.room.member'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 5, EventType: "m.room.member"})
if err != nil {
return err
}
// (6, 'm.room.redaction'),
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 6, EventType: "m.room.redaction"})
if err != nil {
return err
}
// (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
_, err = insertEventTypeCore(s, context.Background(), EventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"})
_, err = insertEventTypeCore(s, context.Background(), eventTypeCosmos{EventTypeNID: 7, EventType: "m.room.history_visibility"})
if err != nil {
return err
}
return nil
}
func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*EventTypeCosmos, error) {
var response EventTypeCosmosData
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, eventType)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
func selectEventTypeCore(s *eventTypeStatements, ctx context.Context, eventType string) (*eventTypeCosmos, error) {
var response eventTypeCosmosData
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), eventType)
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
s.getPartitionKey(),
cosmosDocId,
&response)
@ -281,20 +266,24 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
// SELECT event_type, event_type_nid FROM roomserver_event_types
// WHERE event_type IN ($1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventTypes,
}
response, err := queryEventTypes(s, ctx, s.bulkSelectEventTypeNIDStmt, params)
var rows []eventTypeCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventTypeNIDStmt, params, &rows)
if err != nil {
return nil, err
}
result := make(map[string]types.EventTypeNID, len(eventTypes))
for _, item := range response {
for _, item := range rows {
var eventType string
var eventTypeNID int64
eventType = item.EventType.EventType

View file

@ -46,7 +46,7 @@ import (
// );
// `
type EventCosmos struct {
type eventCosmos struct {
EventNID int64 `json:"event_nid"`
RoomNID int64 `json:"room_nid"`
EventTypeNID int64 `json:"event_type_nid"`
@ -60,13 +60,13 @@ type EventCosmos struct {
IsRejected bool `json:"is_rejected"`
}
type EventCosmosMaxDepth struct {
type eventCosmosMaxDepth struct {
Max int64 `json:"maxdepth"`
}
type EventCosmosData struct {
type eventCosmosData struct {
cosmosdbapi.CosmosDocument
Event EventCosmos `json:"mx_roomserver_event"`
Event eventCosmos `json:"mx_roomserver_event"`
}
// const insertEventSQL = `
@ -173,6 +173,14 @@ type eventStatements struct {
tableName string
}
func (s *eventStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *eventStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func NewCosmosDBEventsTable(db *Database) (tables.Events, error) {
s := &eventStatements{
db: db,
@ -207,29 +215,8 @@ func mapFromEventNIDArray(eventNIDs []types.EventNID) []int64 {
return result
}
func queryEvent(s *eventStatements, ctx context.Context, qry string, params map[string]interface{}) ([]EventCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []EventCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) (*EventCosmosData, error) {
response := EventCosmosData{}
func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) (*eventCosmosData, error) {
response := eventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -245,7 +232,7 @@ func getEvent(s *eventStatements, ctx context.Context, pk string, docId string)
return &response, err
}
func setEvent(s *eventStatements, ctx context.Context, event EventCosmosData) (*EventCosmosData, error) {
func setEvent(s *eventStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -288,7 +275,7 @@ func isReferenceSha256Same(
}
func isEventSame(
event EventCosmos,
event eventCosmos,
roomNID types.RoomNID,
eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID,
@ -343,12 +330,10 @@ func (s *eventStatements) InsertEvent(
// event_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// event_id TEXT NOT NULL UNIQUE,
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := eventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, errGet := getEvent(s, ctx, pk, cosmosDocId)
dbData, errGet := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
// ON CONFLICT DO NOTHING;
// event_nid INTEGER PRIMARY KEY AUTOINCREMENT,
@ -358,7 +343,7 @@ func (s *eventStatements) InsertEvent(
if seqErr != nil {
return 0, 0, seqErr
}
data := EventCosmos{
data := eventCosmos{
AuthEventNIDs: mapFromEventNIDArray(authEventNIDs),
Depth: depth,
EventId: eventID,
@ -370,8 +355,8 @@ func (s *eventStatements) InsertEvent(
RoomNID: int64(roomNID),
}
dbData = &EventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &eventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Event: data,
}
} else {
@ -424,11 +409,9 @@ func (s *eventStatements) SelectEvent(
) (types.EventNID, types.StateSnapshotNID, error) {
// "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
docId := eventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
var response, err = getEvent(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, err = getEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return 0, 0, err
}
@ -451,13 +434,17 @@ func (s *eventStatements) BulkSelectStateEventByID(
// " WHERE event_id IN ($1)" +
// " ORDER BY event_type_nid, event_state_key_nid ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectStateEventByIDStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectStateEventByIDStmt, params, &rows)
if err != nil {
return nil, err
@ -467,9 +454,9 @@ func (s *eventStatements) BulkSelectStateEventByID(
// because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(response))
results := make([]types.StateEntry, len(rows))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID)
result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID)
@ -502,9 +489,8 @@ func (s *eventStatements) BulkSelectStateEventByNID(
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
// params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
// selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
@ -536,7 +522,13 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// return nil, fmt.Errorf("s.db.Prepare: %w", err)
// }
// rows, err := selectStmt.QueryContext(ctx, params...)
rows, err := queryEvent(s, ctx, selectOrig, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectOrig, params, &rows)
if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -578,13 +570,17 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
// " WHERE event_id IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventByIDStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectStateAtEventByIDStmt, params, &rows)
if err != nil {
return nil, err
@ -592,7 +588,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
results := make([]types.StateAtEvent, len(eventIDs))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID)
result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID)
@ -620,19 +616,23 @@ func (s *eventStatements) UpdateEventState(
// "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNID,
}
response, err := queryEvent(s, ctx, s.updateEventStateStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.updateEventStateStmt, params, &rows)
if err != nil {
return err
}
item := response[0]
item := rows[0]
item.Event.StateSnapshotNID = int64(stateNID)
var _, exReplace = setEvent(s, ctx, item)
@ -648,19 +648,23 @@ func (s *eventStatements) SelectEventSentToOutput(
// "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNID,
}
response, err := queryEvent(s, ctx, s.selectEventSentToOutputStmt, params)
var rows []eventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventSentToOutputStmt, params, &rows)
if err != nil {
return false, err
}
item := response[0]
item := rows[0]
sentToOutput = item.Event.SentToOutput
return
}
@ -669,19 +673,23 @@ func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.
// "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNID,
}
response, err := queryEvent(s, ctx, s.updateEventSentToOutputStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.updateEventSentToOutputStmt, params, &rows)
if err != nil {
return err
}
item := response[0]
item := rows[0]
item.Event.SentToOutput = true
var _, exReplace = setEvent(s, ctx, item)
@ -697,19 +705,23 @@ func (s *eventStatements) SelectEventID(
// "SELECT event_id FROM roomserver_events WHERE event_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNID,
}
response, err := queryEvent(s, ctx, s.selectEventIDStmt, params)
var rows []eventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventIDStmt, params, &rows)
if err != nil {
return "", err
}
item := response[0]
item := rows[0]
eventNID = types.EventNID(item.Event.EventNID)
return
}
@ -724,21 +736,25 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
// " FROM roomserver_events WHERE event_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectStateAtEventAndReferenceStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectStateAtEventAndReferenceStmt, params, &rows)
if err != nil {
return nil, err
}
results := make([]types.StateAtEventAndReference, len(response))
results := make([]types.StateAtEventAndReference, len(rows))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID)
result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID)
@ -762,13 +778,17 @@ func (s *eventStatements) BulkSelectEventReference(
}
// "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectEventReferenceStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventReferenceStmt, params, &rows)
if err != nil {
return nil, err
@ -776,7 +796,7 @@ func (s *eventStatements) BulkSelectEventReference(
results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.EventID = item.Event.EventId
result.EventSHA256 = item.Event.ReferenceSha256
@ -797,20 +817,24 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectEventIDStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventIDStmt, params, &rows)
if err != nil {
return nil, err
}
i := 0
for _, item := range response {
for _, item := range rows {
eventNID := item.Event.EventNID
eventID := item.Event.EventId
results[types.EventNID(eventNID)] = eventID
@ -830,20 +854,24 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
}
// "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventIDs,
}
response, err := queryEvent(s, ctx, s.bulkSelectEventNIDStmt, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectEventNIDStmt, params, &rows)
if err != nil {
return nil, err
}
results := make(map[string]types.EventNID, len(eventIDs))
for _, item := range response {
for _, item := range rows {
eventID := item.Event.EventId
eventNID := item.Event.EventNID
results[eventID] = types.EventNID(eventNID)
@ -857,28 +885,23 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
// "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var response []EventCosmosMaxDepth
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": eventNIDs,
}
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(selectMaxEventDepthSQL, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []eventCosmosMaxDepth
err := cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
selectMaxEventDepthSQL, params, &rows)
if err != nil {
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
}
return response[0].Max, nil
return rows[0].Max, nil
}
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
@ -890,20 +913,24 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
// "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventNIDs,
}
response, err := queryEvent(s, ctx, selectRoomNIDsForEventNIDsSQL, params)
var rows []eventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectRoomNIDsForEventNIDsSQL, params, &rows)
if err != nil {
return nil, err
}
result := make(map[types.EventNID]types.RoomNID)
for _, item := range response {
for _, item := range rows {
roomNID := types.RoomNID(item.Event.RoomNID)
eventNID := types.EventNID(item.Event.EventNID)
result[eventNID] = roomNID

View file

@ -40,7 +40,7 @@ import (
// WHERE NOT retired;
// `
type InviteCosmos struct {
type inviteCosmos struct {
InviteEventID string `json:"invite_event_id"`
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
@ -49,9 +49,9 @@ type InviteCosmos struct {
InviteEventJSON []byte `json:"invite_event_json"`
}
type InviteCosmosData struct {
type inviteCosmosData struct {
cosmosdbapi.CosmosDocument
Invite InviteCosmos `json:"mx_roomserver_invite"`
Invite inviteCosmos `json:"mx_roomserver_invite"`
}
// const insertInviteEventSQL = "" +
@ -93,29 +93,16 @@ type inviteStatements struct {
tableName string
}
func queryInvite(s *inviteStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []InviteCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *inviteStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*InviteCosmosData, error) {
response := InviteCosmosData{}
func (s *inviteStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string) (*inviteCosmosData, error) {
response := inviteCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -131,7 +118,7 @@ func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string
return &response, err
}
func setInvite(s *inviteStatements, ctx context.Context, invite InviteCosmosData) (*InviteCosmosData, error) {
func setInvite(s *inviteStatements, ctx context.Context, invite inviteCosmosData) (*inviteCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -168,9 +155,11 @@ func (s *inviteStatements) InsertInviteEvent(
// " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// invite_event_id TEXT PRIMARY KEY,
docId := inviteEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := InviteCosmos{
data := inviteCosmos{
InviteEventID: inviteEventID,
InviteEventJSON: inviteEventJSON,
Retired: false,
@ -179,13 +168,8 @@ func (s *inviteStatements) InsertInviteEvent(
TargetNID: int64(targetUserNID),
}
// invite_event_id TEXT PRIMARY KEY,
docId := inviteEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = InviteCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = inviteCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Invite: data,
}
@ -216,20 +200,24 @@ func (s *inviteStatements) UpdateInviteRetired(
// " AND NOT retired"
// gather all the event IDs we will retire
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": targetUserNID,
"@x3": roomNID,
}
response, err := queryInvite(s, ctx, s.selectInvitesAboutToRetireStmt, params)
var rows []inviteCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInvitesAboutToRetireStmt, params, &rows)
if err != nil {
return
}
for _, item := range response {
for _, item := range rows {
eventIDs = append(eventIDs, item.Invite.InviteEventID)
// UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
@ -249,14 +237,18 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
// SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
"@x3": targetUserNID,
}
response, err := queryInvite(s, ctx, s.selectInviteActiveForUserInRoomStmt, params)
var rows []inviteCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInviteActiveForUserInRoomStmt, params, &rows)
if err != nil {
return nil, nil, err
@ -264,7 +256,7 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom(
var result []types.EventStateKeyNID
var eventIDs []string
for _, item := range response {
for _, item := range rows {
var eventID = item.Invite.InviteEventID
var senderUserNID = item.Invite.SenderNID
result = append(result, types.EventStateKeyNID(senderUserNID))

View file

@ -41,7 +41,7 @@ import (
// );
// `
type MembershipCosmos struct {
type membershipCosmos struct {
RoomNID int64 `json:"room_nid"`
TargetNID int64 `json:"target_nid"`
SenderNID int64 `json:"sender_nid"`
@ -51,9 +51,9 @@ type MembershipCosmos struct {
Forgotten bool `json:"forgotten"`
}
type MembershipCosmosData struct {
type membershipCosmosData struct {
cosmosdbapi.CosmosDocument
Membership MembershipCosmos `json:"mx_roomserver_membership"`
Membership membershipCosmos `json:"mx_roomserver_membership"`
}
type MembershipJoinedCountCosmosData struct {
@ -199,29 +199,24 @@ type membershipStatements struct {
tableName string
}
func queryMembership(s *membershipStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []MembershipCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *membershipStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) {
response := MembershipCosmosData{}
func (s *membershipStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func (s *membershipStatements) getCollectionEventStateKeys() string {
return "roomserver_event_state_keys"
}
func (s *membershipStatements) getPartitionKeyEventStateKeys() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionEventStateKeys())
}
func getMembership(s *membershipStatements, ctx context.Context, pk string, docId string) (*membershipCosmosData, error) {
response := membershipCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -237,7 +232,7 @@ func getMembership(s *membershipStatements, ctx context.Context, pk string, docI
return &response, err
}
func setMembership(s *membershipStatements, ctx context.Context, membership MembershipCosmosData) (*MembershipCosmosData, error) {
func setMembership(s *membershipStatements, ctx context.Context, membership membershipCosmosData) (*membershipCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(membership.Pk, membership.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -284,15 +279,12 @@ func (s *membershipStatements) InsertMembership(
// " VALUES ($1, $2, $3)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_nid, target_nid)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// " ON CONFLICT DO NOTHING"
exists, _ := getMembership(s, ctx, pk, cosmosDocId)
exists, _ := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if exists != nil {
exists.Membership.RoomNID = int64(roomNID)
exists.Membership.TargetNID = int64(targetUserNID)
@ -302,7 +294,7 @@ func (s *membershipStatements) InsertMembership(
return errSet
}
data := MembershipCosmos{
data := membershipCosmos{
EventNID: 0,
Forgotten: false,
MembershipNID: 1,
@ -312,8 +304,8 @@ func (s *membershipStatements) InsertMembership(
TargetNID: int64(targetUserNID),
}
var dbData = MembershipCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = membershipCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Membership: data,
}
@ -336,11 +328,10 @@ func (s *membershipStatements) SelectMembershipForUpdate(
// "SELECT membership_nid FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if response != nil {
membership = tables.MembershipState(response.Membership.MembershipNID)
}
@ -355,11 +346,9 @@ func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
// "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
// " WHERE room_nid = $1 AND target_nid = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
response, err := getMembership(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if response != nil {
eventNID = types.EventNID(response.Membership.EventNID)
forgotten = response.Membership.Forgotten
@ -374,9 +363,8 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
) (eventNIDs []types.EventNID, err error) {
var selectStmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
}
if localOnly {
@ -390,12 +378,19 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
// " WHERE room_nid = $1 and forgotten = false"
selectStmt = s.selectMembershipsFromRoomStmt
}
response, err := queryMembership(s, ctx, selectStmt, params)
var rows []membershipCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
}
return
@ -407,9 +402,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
) (eventNIDs []types.EventNID, err error) {
var stmt string
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
"@x3": membership,
}
@ -423,12 +417,18 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
// " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
response, err := queryMembership(s, ctx, stmt, params)
var rows []membershipCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), stmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
eventNIDs = append(eventNIDs, types.EventNID(item.Membership.EventNID))
}
return
@ -443,11 +443,9 @@ func (s *membershipStatements) UpdateMembership(
// "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
// " WHERE room_nid = $5 AND target_nid = $6"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err
@ -468,18 +466,23 @@ func (s *membershipStatements) SelectRoomsWithMembership(
// "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": membershipState,
"@x3": userID,
}
response, err := queryMembership(s, ctx, s.selectRoomsWithMembershipStmt, params)
var rows []membershipCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectRoomsWithMembershipStmt, params, &rows)
if err != nil {
return nil, err
}
var roomNIDs []types.RoomNID
for _, item := range response {
for _, item := range rows {
roomNIDs = append(roomNIDs, types.RoomNID(item.Membership.RoomNID))
}
return roomNIDs, nil
@ -495,30 +498,24 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
// " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
// " GROUP BY target_nid"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNIDs,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []MembershipJoinedCountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectJoinedUsersSetForRoomsSQL, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []MembershipJoinedCountCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
s.getPartitionKey(), selectJoinedUsersSetForRoomsSQL, params, &rows)
if err != nil {
return nil, err
}
result := make(map[types.EventStateKeyNID]int)
for _, item := range response {
for _, item := range rows {
userID := types.EventStateKeyNID(item.TargetNID)
count := item.RoomCount
result[userID] = count
@ -532,20 +529,25 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
var nid types.RoomNID
// err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
//
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": tables.MembershipStateJoin,
"@x3": roomNID,
}
response, err := queryMembership(s, ctx, s.selectLocalServerInRoomStmt, params)
if len(response) == 0 {
var rows []membershipCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectLocalServerInRoomStmt, params, &rows)
if len(rows) == 0 {
if err == cosmosdbutil.ErrNoRows {
return false, nil
}
return false, err
}
nid = types.RoomNID(response[0].Membership.RoomNID)
nid = types.RoomNID(rows[0].Membership.RoomNID)
found := nid > 0
return found, nil
@ -564,27 +566,20 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID t
"select * from c where c._cn = @x1 " +
"and (endswith(c.mx_roomserver_event_state_keys.event_state_key, \":\") or c.mx_roomserver_event_state_keys.event_state_key = @x2) "
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionEventStateKeys(),
"@x2": serverName,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var eventStateKeys []EventStateKeysCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectEventStateKeyNIDSQL, params) //
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rowsEventsStateKeys []eventStateKeysCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&eventStateKeys,
optionsQry)
s.getPartitionKeyEventStateKeys(), selectEventStateKeyNIDSQL, params, &rowsEventsStateKeys)
eventStateKeyNids := []int64{}
for _, item := range eventStateKeys {
for _, item := range rowsEventsStateKeys {
eventStateKeyNids = append(eventStateKeyNids, item.EventStateKeys.EventStateKeyNID)
}
@ -595,19 +590,25 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID t
// err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
params = map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": tables.MembershipStateJoin,
"@x3": roomNID,
"@x4": eventStateKeyNids,
}
response, err := queryMembership(s, ctx, s.selectServerInRoomStmt, params)
if len(response) == 0 {
var rows []membershipCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectServerInRoomStmt, params, &rows)
if len(rows) == 0 {
if err == cosmosdbutil.ErrNoRows {
return false, nil
}
return false, err
}
nid = types.RoomNID(response[0].Membership.RoomNID)
nid = types.RoomNID(rows[0].Membership.RoomNID)
return roomNID == nid, nil
}
@ -616,33 +617,26 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
// ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": searchString,
"@x4": limit,
}
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var responseDistinctRoom []MembershipCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(selectKnownUsersSQLDistinctRoom, params) //
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rowsDistinctRoom []membershipCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseDistinctRoom,
optionsQry)
s.getPartitionKey(), selectKnownUsersSQLDistinctRoom, params, &rowsDistinctRoom)
if err != nil {
return nil, err
}
rooms := []int64{}
for _, item := range responseDistinctRoom {
for _, item := range rowsDistinctRoom {
rooms = append(rooms, item.RoomNID)
}
@ -651,47 +645,40 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
// " WHERE room_nid IN (" +
params = map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": rooms,
}
var responseRooms []MembershipCosmos
query = cosmosdbapi.GetQuery(selectKnownUsersSQLRooms, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []membershipCosmos
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseRooms,
optionsQry)
s.getPartitionKey(), selectKnownUsersSQLRooms, params, &rows)
if err != nil {
return nil, err
}
targetNIDs := []int64{}
for _, item := range responseRooms {
for _, item := range rows {
targetNIDs = append(targetNIDs, item.TargetNID)
}
// HACK: Joined table
var dbCollectionNameEventStateKeys = cosmosdbapi.GetCollectionName(s.db.databaseName, "event_state_keys")
params = map[string]interface{}{
"@x1": dbCollectionNameEventStateKeys,
"@x1": s.getCollectionEventStateKeys(),
"@x2": targetNIDs,
}
bulkSelectEventStateKeyStmt := "select * from c where c._cn = @x1 and ARRAY_CONTAINS(@x2, c.mx_roomserver_event_state_keys.event_state_key_nid)"
var responseEventStateKeys []EventStateKeysCosmos
query = cosmosdbapi.GetQuery(bulkSelectEventStateKeyStmt, params)
_, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rowsEventStateKeys []eventStateKeysCosmos
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&responseEventStateKeys,
optionsQry)
s.getPartitionKeyEventStateKeys(), bulkSelectEventStateKeyStmt, params, &rowsEventStateKeys)
if err != nil {
return nil, err
@ -700,7 +687,7 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
// SELECT DISTINCT event_state_key
result := []string{}
for _, item := range responseEventStateKeys {
for _, item := range rowsEventStateKeys {
userID := item.EventStateKey
result = append(result, userID)
}
@ -716,11 +703,9 @@ func (s *membershipStatements) UpdateForgetMembership(
// "UPDATE roomserver_membership SET forgotten = $1" +
// " WHERE room_nid = $2 AND target_nid = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := fmt.Sprintf("%d_%d", roomNID, targetUserNID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
dbData, err := getMembership(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, err := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err

View file

@ -43,15 +43,15 @@ import (
// );
// `
type PreviousEventCosmos struct {
type previousEventCosmos struct {
PreviousEventID string `json:"previous_event_id"`
PreviousReferenceSha256 []byte `json:"previous_reference_sha256"`
EventNIDs string `json:"event_nids"`
}
type PreviousEventCosmosData struct {
type previousEventCosmosData struct {
cosmosdbapi.CosmosDocument
PreviousEvent PreviousEventCosmos `json:"mx_roomserver_previous_event"`
PreviousEvent previousEventCosmos `json:"mx_roomserver_previous_event"`
}
// Insert an entry into the previous_events table.
@ -85,8 +85,16 @@ type previousEventStatements struct {
tableName string
}
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*PreviousEventCosmosData, error) {
response := PreviousEventCosmosData{}
func (s *previousEventStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *previousEventStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getPreviousEvent(s *previousEventStatements, ctx context.Context, pk string, docId string) (*previousEventCosmosData, error) {
response := previousEventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -125,18 +133,15 @@ func (s *previousEventStatements) InsertPreviousEvent(
) error {
eventNIDAsString := fmt.Sprintf("%d", eventNID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := previousEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
existing, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
existing, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
if err != cosmosdbutil.ErrNoRows {
@ -144,17 +149,17 @@ func (s *previousEventStatements) InsertPreviousEvent(
}
}
var dbData PreviousEventCosmosData
var dbData previousEventCosmosData
// Doesnt exist, create a new one
if existing == nil {
data := PreviousEventCosmos{
data := previousEventCosmos{
EventNIDs: "",
PreviousEventID: previousEventID,
PreviousReferenceSha256: previousEventReferenceSHA256,
}
dbData = PreviousEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = previousEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
PreviousEvent: data,
}
} else {
@ -192,18 +197,16 @@ func (s *previousEventStatements) InsertPreviousEvent(
func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (previous_event_id, previous_reference_sha256)
// TODO: Check value
// docId := fmt.Sprintf("%s_%s", previousEventID, previousEventReferenceSHA256)
docId := eventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, string(docId))
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), string(docId))
// SELECT 1 FROM roomserver_previous_events
// WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
dbData, err := getPreviousEvent(s, ctx, pk, cosmosDocId)
dbData, err := getPreviousEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err
}

View file

@ -34,14 +34,14 @@ import (
// );
// `
type PublishCosmos struct {
type publishCosmos struct {
RoomID string `json:"room_id"`
Published bool `json:"published"`
}
type PublishCosmosData struct {
type publishCosmosData struct {
cosmosdbapi.CosmosDocument
Publish PublishCosmos `json:"mx_roomserver_publish"`
Publish publishCosmos `json:"mx_roomserver_publish"`
}
// const upsertPublishedSQL = "" +
@ -64,29 +64,16 @@ type publishedStatements struct {
tableName string
}
func queryPublish(s *publishedStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PublishCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []PublishCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *publishedStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*PublishCosmosData, error) {
response := PublishCosmosData{}
func (s *publishedStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getPublish(s *publishedStatements, ctx context.Context, pk string, docId string) (*publishCosmosData, error) {
response := publishCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -125,24 +112,21 @@ func (s *publishedStatements) UpsertRoomPublished(
// "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
dbData, _ := getPublish(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getPublish(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.Publish.Published = published
} else {
data := PublishCosmos{
data := publishCosmos{
RoomID: roomID,
Published: false,
}
dbData = &PublishCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &publishCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Publish: data,
}
}
@ -161,13 +145,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
) (published bool, err error) {
// "SELECT published FROM roomserver_published WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL PRIMARY KEY,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
response, err := getPublish(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getPublish(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return false, err
}
@ -179,19 +160,24 @@ func (s *publishedStatements) SelectAllPublishedRooms(
) ([]string, error) {
// "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": published,
}
response, err := queryPublish(s, ctx, s.selectAllPublishedStmt, params)
var rows []publishCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAllPublishedStmt, params, &rows)
if err != nil {
return nil, err
}
var roomIDs []string
for _, item := range response {
for _, item := range rows {
roomIDs = append(roomIDs, item.Publish.RoomID)
}
return roomIDs, nil

View file

@ -37,15 +37,15 @@ import (
// );
// `
type RedactionCosmos struct {
type redactionCosmos struct {
RedactionEventID string `json:"redaction_event_id"`
RedactsEventID string `json:"redacts_event_id"`
Validated bool `json:"validated"`
}
type RedactionCosmosData struct {
type redactionCosmosData struct {
cosmosdbapi.CosmosDocument
Redaction RedactionCosmos `json:"mx_roomserver_redaction"`
Redaction redactionCosmos `json:"mx_roomserver_redaction"`
}
// const insertRedactionSQL = "" +
@ -74,29 +74,16 @@ type redactionStatements struct {
tableName string
}
func queryRedaction(s *redactionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RedactionCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []RedactionCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *redactionStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*RedactionCosmosData, error) {
response := RedactionCosmosData{}
func (s *redactionStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId string) (*redactionCosmosData, error) {
response := redactionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -112,7 +99,7 @@ func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId
return &response, err
}
func setRedaction(s *redactionStatements, ctx context.Context, redaction RedactionCosmosData) (*RedactionCosmosData, error) {
func setRedaction(s *redactionStatements, ctx context.Context, redaction redactionCosmosData) (*redactionCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(redaction.Pk, redaction.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -146,20 +133,18 @@ func (s *redactionStatements) InsertRedaction(
// "INSERT OR IGNORE INTO roomserver_redactions (redaction_event_id, redacts_event_id, validated)" +
// " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := info.RedactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := RedactionCosmos{
data := redactionCosmos{
RedactionEventID: info.RedactionEventID,
RedactsEventID: info.RedactsEventID,
Validated: info.Validated,
}
var dbData = RedactionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = redactionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Redaction: data,
}
@ -188,13 +173,11 @@ func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redaction_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
info = nil
err = err
@ -221,27 +204,31 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
// "SELECT redaction_event_id, redacts_event_id, validated FROM roomserver_redactions" +
// " WHERE redacts_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventID,
}
response, err := queryRedaction(s, ctx, s.selectRedactionInfoByEventBeingRedactedStmt, params)
var rows []redactionCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectRedactionInfoByEventBeingRedactedStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
if len(rows) == 0 {
info = nil
err = nil
return
}
// TODO: Check this is ok to return the 1st one
*info = tables.RedactionInfo{
RedactionEventID: response[0].Redaction.RedactionEventID,
RedactsEventID: response[0].Redaction.RedactsEventID,
Validated: response[0].Redaction.Validated,
RedactionEventID: rows[0].Redaction.RedactionEventID,
RedactsEventID: rows[0].Redaction.RedactsEventID,
Validated: rows[0].Redaction.Validated,
}
return
}
@ -252,13 +239,11 @@ func (s *redactionStatements) MarkRedactionValidated(
// " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// redaction_event_id TEXT PRIMARY KEY,
docId := redactionEventID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getRedaction(s, ctx, pk, cosmosDocId)
response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err
}

View file

@ -33,15 +33,15 @@ import (
// CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
// `
type RoomAliasCosmos struct {
type roomAliasCosmos struct {
Alias string `json:"alias"`
RoomID string `json:"room_id"`
CreatorID string `json:"creator_id"`
}
type RoomAliasCosmosData struct {
type roomAliasCosmosData struct {
cosmosdbapi.CosmosDocument
RoomAlias RoomAliasCosmos `json:"mx_roomserver_room_alias"`
RoomAlias roomAliasCosmos `json:"mx_roomserver_room_alias"`
}
// const insertRoomAliasSQL = `
@ -75,29 +75,16 @@ type roomAliasesStatements struct {
tableName string
}
func queryRoomAlias(s *roomAliasesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomAliasCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []RoomAliasCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *roomAliasesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*RoomAliasCosmosData, error) {
response := RoomAliasCosmosData{}
func (s *roomAliasesStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getRoomAlias(s *roomAliasesStatements, ctx context.Context, pk string, docId string) (*roomAliasCosmosData, error) {
response := roomAliasCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -137,21 +124,19 @@ func (s *roomAliasesStatements) InsertRoomAlias(
) error {
// INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
data := RoomAliasCosmos{
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := roomAliasCosmos{
Alias: alias,
CreatorID: creatorUserID,
RoomID: roomID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = RoomAliasCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = roomAliasCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
RoomAlias: data,
}
@ -172,13 +157,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
// SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getRoomAlias(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return "", err
@ -198,19 +180,23 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
// SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
response, err := queryRoomAlias(s, ctx, s.selectAliasesFromRoomIDStmt, params)
var rows []roomAliasCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAliasesFromRoomIDStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
aliases = append(aliases, item.RoomAlias.Alias)
}
@ -223,13 +209,10 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
// SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// alias TEXT NOT NULL PRIMARY KEY,
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
response, err := getRoomAlias(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getRoomAlias(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return "", err
@ -248,11 +231,9 @@ func (s *roomAliasesStatements) DeleteRoomAlias(
// DELETE FROM roomserver_room_aliases WHERE alias = $1
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
docId := alias
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey())
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,

View file

@ -40,12 +40,7 @@ import (
// );
// `
type RoomCosmosData struct {
cosmosdbapi.CosmosDocument
Room RoomCosmos `json:"mx_roomserver_room"`
}
type RoomCosmos struct {
type roomCosmos struct {
RoomNID int64 `json:"room_nid"`
RoomID string `json:"room_id"`
LatestEventNIDs []int64 `json:"latest_event_nids"`
@ -54,6 +49,11 @@ type RoomCosmos struct {
RoomVersion string `json:"room_version"`
}
type roomCosmosData struct {
cosmosdbapi.CosmosDocument
Room roomCosmos `json:"mx_roomserver_room"`
}
// Same as insertEventTypeNIDSQL
// const insertRoomNIDSQL = `
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
@ -113,6 +113,14 @@ type roomStatements struct {
tableName string
}
func (s *roomStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *roomStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func NewCosmosDBRoomsTable(db *Database) (tables.Rooms, error) {
s := &roomStatements{
db: db,
@ -139,29 +147,8 @@ func mapToRoomEventNIDArray(eventNIDs []int64) []types.EventNID {
return result
}
func queryRoom(s *roomStatements, ctx context.Context, qry string, params map[string]interface{}) ([]RoomCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []RoomCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*RoomCosmosData, error) {
response := RoomCosmosData{}
func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*roomCosmosData, error) {
response := roomCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -177,7 +164,7 @@ func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (*
return &response, err
}
func setRoom(s *roomStatements, ctx context.Context, room RoomCosmosData) (*RoomCosmosData, error) {
func setRoom(s *roomStatements, ctx context.Context, room roomCosmosData) (*roomCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(room.Pk, room.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -193,19 +180,23 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
// "SELECT room_id FROM roomserver_rooms"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
response, err := queryRoom(s, ctx, s.selectRoomIDsStmt, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectRoomIDsStmt, params, &rows)
if err != nil {
return nil, err
}
var roomIDs []string
for _, item := range response {
for _, item := range rows {
roomIDs = append(roomIDs, item.Room.RoomID)
}
return roomIDs, nil
@ -216,12 +207,10 @@ func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*ty
// "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
room, err := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
if err == cosmosdbutil.ErrNoRows {
@ -242,16 +231,13 @@ func (s *roomStatements) InsertRoomNID(
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (roomNID types.RoomNID, err error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// INSERT INTO roomserver_rooms (room_id, room_version) VALUES ($1, $2)
// ON CONFLICT DO NOTHING;
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, errGet := getRoom(s, ctx, pk, cosmosDocId)
dbData, errGet := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId)
if errGet == cosmosdbutil.ErrNoRows {
// room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
@ -260,14 +246,14 @@ func (s *roomStatements) InsertRoomNID(
return 0, seqErr
}
data := RoomCosmos{
data := roomCosmos{
RoomNID: int64(roomNIDSeq),
RoomID: roomID,
RoomVersion: string(roomVersion),
}
dbData = &RoomCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &roomCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Room: data,
}
} else {
@ -298,12 +284,10 @@ func (s *roomStatements) SelectRoomNID(
// "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// room_id TEXT NOT NULL UNIQUE,
docId := roomID
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
room, err := getRoom(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
room, err := getRoom(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return 0, err
@ -321,25 +305,29 @@ func (s *roomStatements) SelectLatestEventNIDs(
// "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsStmt, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectLatestEventNIDsStmt, params, &rows)
if err != nil {
return nil, 0, err
}
// TODO: Check the error handling
if len(response) == 0 {
if len(rows) == 0 {
return nil, 0, cosmosdbutil.ErrNoRows
}
//Assume 1 per RoomNID
room := response[0]
room := rows[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
}
@ -349,25 +337,29 @@ func (s *roomStatements) SelectLatestEventsNIDsForUpdate(
// "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectLatestEventNIDsForUpdateStmt, params, &rows)
if err != nil {
return nil, 0, 0, err
}
// TODO: Check the error handling
if len(response) == 0 {
if len(rows) == 0 {
return nil, 0, 0, cosmosdbutil.ErrNoRows
}
//Assume 1 per RoomNID
room := response[0]
room := rows[0]
return mapToRoomEventNIDArray(room.Room.LatestEventNIDs), types.EventNID(room.Room.LastEventSentNID), types.StateSnapshotNID(room.Room.StateSnapshotNID), nil
}
@ -382,25 +374,29 @@ func (s *roomStatements) UpdateLatestEventNIDs(
// "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNID,
}
response, err := queryRoom(s, ctx, s.selectLatestEventNIDsForUpdateStmt, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectLatestEventNIDsForUpdateStmt, params, &rows)
if err != nil {
return err
}
// TODO: Check the error handling
if len(response) == 0 {
if len(rows) == 0 {
return cosmosdbutil.ErrNoRows
}
//Assume 1 per RoomNID
room := response[0]
room := rows[0]
room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs)
room.Room.LastEventSentNID = int64(lastEventSentNID)
@ -419,20 +415,24 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
// "SELECT room_nid, room_version FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, selectRoomVersionsForRoomNIDsSQL, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectRoomVersionForRoomNIDStmt, params, &rows)
if err != nil {
return nil, err
}
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for _, item := range response {
for _, item := range rows {
result[types.RoomNID(item.Room.RoomNID)] = gomatrixserverlib.RoomVersion(item.Room.RoomVersion)
}
return result, nil
@ -445,20 +445,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
// "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomNIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomIDsSQL, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), bulkSelectRoomIDsSQL, params, &rows)
if err != nil {
return nil, err
}
var roomIDs []string
for _, item := range response {
for _, item := range rows {
roomIDs = append(roomIDs, item.Room.RoomID)
}
return roomIDs, nil
@ -471,20 +475,24 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []strin
// "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomIDs,
}
response, err := queryRoom(s, ctx, bulkSelectRoomNIDsSQL, params)
var rows []roomCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), bulkSelectRoomNIDsSQL, params, &rows)
if err != nil {
return nil, err
}
var roomNIDs []types.RoomNID
for _, item := range response {
for _, item := range rows {
roomNIDs = append(roomNIDs, types.RoomNID(item.Room.RoomNID))
}
return roomNIDs, nil

View file

@ -42,19 +42,19 @@ import (
// );
// `
type StateBlockCosmos struct {
type stateBlockCosmos struct {
StateBlockNID int64 `json:"state_block_nid"`
StateBlockHash []byte `json:"state_block_hash"`
EventNIDs []int64 `json:"event_nids"`
}
type StateBlockCosmosMaxNID struct {
type stateBlockCosmosMaxNID struct {
Max int64 `json:"maxstateblocknid"`
}
type StateBlockCosmosData struct {
type stateBlockCosmosData struct {
cosmosdbapi.CosmosDocument
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
StateBlock stateBlockCosmos `json:"mx_roomserver_state_block"`
}
// Insert a new state block. If we conflict on the hash column then
@ -82,29 +82,16 @@ type stateBlockStatements struct {
tableName string
}
func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []StateBlockCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *stateBlockStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*StateBlockCosmosData, error) {
response := StateBlockCosmosData{}
func (s *stateBlockStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*stateBlockCosmosData, error) {
response := stateBlockCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -120,7 +107,7 @@ func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docI
return &response, err
}
func setStateBlock(s *stateBlockStatements, ctx context.Context, item StateBlockCosmosData) (*StateBlockCosmosData, error) {
func setStateBlock(s *stateBlockStatements, ctx context.Context, item stateBlockCosmosData) (*stateBlockCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(item.Pk, item.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -168,16 +155,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
// ctx, nids.Hash(), js,
// ).Scan(&id)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// state_block_hash BLOB UNIQUE,
docId := hex.EncodeToString(nids.Hash())
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
//See if it exists
existing, err := getStateBlock(s, ctx, pk, cosmosDocId)
existing, err := getStateBlock(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
if err != cosmosdbutil.ErrNoRows {
return 0, err
@ -199,14 +182,14 @@ func (s *stateBlockStatements) BulkInsertStateData(
seq, err := GetNextStateBlockNID(s, ctx)
id = types.StateBlockNID(seq)
data := StateBlockCosmos{
data := stateBlockCosmos{
StateBlockNID: seq,
StateBlockHash: nids.Hash(),
EventNIDs: ids,
}
var dbData = StateBlockCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = stateBlockCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
StateBlock: data,
}
@ -235,14 +218,18 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
// if err != nil {
// return nil, err
// }
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": stateBlockNIDs,
}
// rows, err := selectStmt.QueryContext(ctx, intfs...)
rows, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
var rows []stateBlockCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.bulkSelectStateBlockEntriesStmt, params, &rows)
if err != nil {
return nil, err

View file

@ -47,16 +47,16 @@ import (
// state_block_nids TEXT NOT NULL DEFAULT '[]'
// );
type StateSnapshotCosmos struct {
type stateSnapshotCosmos struct {
StateSnapshotNID int64 `json:"state_snapshot_nid"`
StateSnapshotHash []byte `json:"state_snapshot_hash"`
RoomNID int64 `json:"room_nid"`
StateBlockNIDs []int64 `json:"state_block_nids"`
}
type StateSnapshotCosmosData struct {
type stateSnapshotCosmosData struct {
cosmosdbapi.CosmosDocument
StateSnapshot StateSnapshotCosmos `json:"mx_roomserver_state_snapshot"`
StateSnapshot stateSnapshotCosmos `json:"mx_roomserver_state_snapshot"`
}
// const insertStateSQL = `
@ -82,6 +82,14 @@ type stateSnapshotStatements struct {
tableName string
}
func (s *stateSnapshotStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *stateSnapshotStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func mapFromStateBlockNIDArray(stateBlockNIDs []types.StateBlockNID) []int64 {
result := []int64{}
for i := 0; i < len(stateBlockNIDs); i++ {
@ -133,22 +141,19 @@ func (s *stateSnapshotStatements) InsertState(
// return
// }
data := StateSnapshotCosmos{
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", stateSnapshotNIDSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := stateSnapshotCosmos{
RoomNID: int64(roomNID),
StateSnapshotHash: stateBlockNIDs.Hash(),
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
StateSnapshotNID: int64(stateSnapshotNIDSeq),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", stateSnapshotNIDSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = StateSnapshotCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = stateSnapshotCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
StateSnapshot: data,
}
@ -174,30 +179,25 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
// "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
// " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []StateSnapshotCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": stateNIDs,
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.bulkSelectStateBlockNIDsStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []stateSnapshotCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
s.getPartitionKey(), s.bulkSelectStateBlockNIDsStmt, params, &rows)
if err != nil {
return nil, err
}
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for _, item := range response {
for _, item := range rows {
result := &results[i]
result.StateSnapshotNID = types.StateSnapshotNID(item.StateSnapshot.StateSnapshotNID)
result.StateBlockNIDs = mapToStateBlockNIDArray(item.StateSnapshot.StateBlockNIDs)

View file

@ -35,16 +35,16 @@ import (
// );
// `
type TransactionCosmos struct {
type transactionCosmos struct {
TransactionID string `json:"transaction_id"`
SessionID int64 `json:"session_id"`
UserID string `json:"user_id"`
EventID string `json:"event_id"`
}
type TransactionCosmosData struct {
type transactionCosmosData struct {
cosmosdbapi.CosmosDocument
Transaction TransactionCosmos `json:"mx_roomserver_transaction"`
Transaction transactionCosmos `json:"mx_roomserver_transaction"`
}
// const insertTransactionSQL = `
@ -64,8 +64,16 @@ type transactionStatements struct {
tableName string
}
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*TransactionCosmosData, error) {
response := TransactionCosmosData{}
func (s *transactionStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *transactionStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getTransaction(s *transactionStatements, ctx context.Context, pk string, docId string) (*transactionCosmosData, error) {
response := transactionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -103,22 +111,20 @@ func (s *transactionStatements) InsertTransaction(
// INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
// VALUES ($1, $2, $3, $4)
data := TransactionCosmos{
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := transactionCosmos{
EventID: eventID,
SessionID: sessionID,
TransactionID: transactionID,
UserID: userID,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = TransactionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = transactionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Transaction: data,
}
@ -143,13 +149,11 @@ func (s *transactionStatements) SelectTransactionEventID(
// SELECT event_id FROM roomserver_transactions
// WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY (transaction_id, session_id, user_id)
docId := fmt.Sprintf("%s_%d_%s", transactionID, sessionID, userID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
response, err := getTransaction(s, ctx, pk, cosmosDocId)
response, err := getTransaction(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return "", err

View file

@ -51,7 +51,7 @@ import (
// CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
// `
type ServerKeyCosmos struct {
type serverKeyCosmos struct {
ServerName string `json:"server_name"`
ServerKeyID string `json:"server_key_id"`
ServerNameAndKeyID string `json:"server_name_and_key_id"`
@ -60,9 +60,9 @@ type ServerKeyCosmos struct {
ServerKey string `json:"server_key"`
}
type ServerKeyCosmosData struct {
type serverKeyCosmosData struct {
cosmosdbapi.CosmosDocument
ServerKey ServerKeyCosmos `json:"mx_keydb_server_key"`
ServerKey serverKeyCosmos `json:"mx_keydb_server_key"`
}
// "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
@ -87,8 +87,16 @@ type serverKeyStatements struct {
tableName string
}
func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId string) (*ServerKeyCosmosData, error) {
response := ServerKeyCosmosData{}
func (s *serverKeyStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *serverKeyStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId string) (*serverKeyCosmosData, error) {
response := serverKeyCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -104,27 +112,6 @@ func getServerKey(s *serverKeyStatements, ctx context.Context, pk string, docId
return &response, err
}
func queryServerKey(s *serverKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ServerKeyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ServerKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *serverKeyStatements) prepare(db *Database, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
@ -146,9 +133,8 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
// iKeyIDs[i] = v
// }
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": nameAndKeyIDs,
}
@ -159,8 +145,12 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
// err := sqlutil.RunLimitedVariablesQuery(
// ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
// func(rows *sql.Rows) error {
rows, err := queryServerKey(s, ctx, bulkSelectServerKeysSQL, params)
var rows []serverKeyCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), bulkSelectServerKeysSQL, params, &rows)
if err != nil {
return nil, err
@ -213,20 +203,18 @@ func (s *serverKeyStatements) upsertServerKeys(
// stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name, server_key_id)
docId := fmt.Sprintf("%s_%s", string(request.ServerName), string(request.KeyID))
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getServerKey(s, ctx, pk, cosmosDocId)
dbData, _ := getServerKey(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.ServerKey.ValidUntilTimestamp = int64(key.ValidUntilTS)
dbData.ServerKey.ExpiredTimestamp = int64(key.ExpiredTS)
dbData.ServerKey.ServerKey = key.Key.Encode()
} else {
data := ServerKeyCosmos{
data := serverKeyCosmos{
ServerName: string(request.ServerName),
ServerKeyID: string(request.KeyID),
ServerNameAndKeyID: nameAndKeyID(request),
@ -235,8 +223,8 @@ func (s *serverKeyStatements) upsertServerKeys(
ServerKey: key.Key.Encode(),
}
dbData = &ServerKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &serverKeyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
ServerKey: data,
}
}

View file

@ -39,7 +39,7 @@ import (
// );
// `
type AccountDataTypeCosmos struct {
type accountDataTypeCosmos struct {
ID int64 `json:"id"`
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
@ -50,9 +50,9 @@ type AccountDataTypeNumberCosmosData struct {
Number int64 `json:"number"`
}
type AccountDataTypeCosmosData struct {
type accountDataTypeCosmosData struct {
cosmosdbapi.CosmosDocument
AccountDataType AccountDataTypeCosmos `json:"mx_syncapi_account_data_type"`
AccountDataType accountDataTypeCosmos `json:"mx_syncapi_account_data_type"`
}
// const insertAccountDataSQL = "" +
@ -83,8 +83,16 @@ type accountDataStatements struct {
tableName string
}
func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string, docId string) (*AccountDataTypeCosmosData, error) {
response := AccountDataTypeCosmosData{}
func (s *accountDataStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *accountDataStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string, docId string) (*accountDataTypeCosmosData, error) {
response := accountDataTypeCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -100,48 +108,6 @@ func getAccountDataType(s *accountDataStatements, ctx context.Context, pk string
return &response, err
}
func queryAccountDataType(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []AccountDataTypeCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryAccountDataTypeNumber(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeNumberCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []AccountDataTypeNumberCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
}
func NewCosmosDBAccountDataTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
@ -168,26 +134,24 @@ func (s *accountDataStatements) InsertAccountData(
return
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (user_id, room_id, type)
docId := fmt.Sprintf("%s_%s_%s", userID, roomID, dataType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getAccountDataType(s, ctx, pk, cosmosDocId)
dbData, _ := getAccountDataType(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
dbData.AccountDataType.ID = int64(pos)
} else {
data := AccountDataTypeCosmos{
data := accountDataTypeCosmos{
ID: int64(pos),
UserID: userID,
RoomID: roomID,
DataType: dataType,
}
dbData = &AccountDataTypeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &accountDataTypeCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
AccountDataType: data,
}
}
@ -216,15 +180,18 @@ func (s *accountDataStatements) SelectAccountDataInRange(
// " ORDER BY id ASC"
// rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": r.Low(),
"@x4": r.High(),
}
rows, err := queryAccountDataType(s, ctx, s.selectAccountDataInRangeStmt, params)
var rows []accountDataTypeCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountDataInRangeStmt, params, &rows)
if err != nil {
return
@ -276,12 +243,16 @@ func (s *accountDataStatements) SelectMaxAccountDataID(
var nullableID sql.NullInt64
// err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
rows, err := queryAccountDataTypeNumber(s, ctx, s.selectMaxAccountDataIDStmt, params)
var rows []AccountDataTypeNumberCosmosData
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxAccountDataIDStmt, params, &rows)
if err != cosmosdbutil.ErrNoRows && len(rows) == 1 {
nullableID.Int64 = rows[0].Number

View file

@ -37,15 +37,15 @@ import (
// );
// `
type BackwardExtremityCosmos struct {
type backwardExtremityCosmos struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
PrevEventID string `json:"prev_event_id"`
}
type BackwardExtremityCosmosData struct {
type backwardExtremityCosmosData struct {
cosmosdbapi.CosmosDocument
BackwardExtremity BackwardExtremityCosmos `json:"mx_syncapi_backward_extremity"`
BackwardExtremity backwardExtremityCosmos `json:"mx_syncapi_backward_extremity"`
}
// const insertBackwardExtremitySQL = "" +
@ -78,8 +78,16 @@ type backwardExtremitiesStatements struct {
tableName string
}
func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*BackwardExtremityCosmosData, error) {
response := BackwardExtremityCosmosData{}
func (s *backwardExtremitiesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *backwardExtremitiesStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, pk string, docId string) (*backwardExtremityCosmosData, error) {
response := backwardExtremityCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -95,28 +103,7 @@ func getBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context,
return &response, err
}
func queryBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BackwardExtremityCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []BackwardExtremityCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData BackwardExtremityCosmosData) error {
func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData backwardExtremityCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -152,24 +139,22 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
// _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// PRIMARY KEY(room_id, event_id, prev_event_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getBackwardExtremity(s, ctx, pk, cosmosDocId)
dbData, _ := getBackwardExtremity(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
dbData.SetUpdateTime()
} else {
data := BackwardExtremityCosmos{
data := backwardExtremityCosmos{
EventID: eventID,
PrevEventID: prevEventID,
RoomID: roomID,
}
dbData = &BackwardExtremityCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &backwardExtremityCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
BackwardExtremity: data,
}
}
@ -191,13 +176,16 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
// "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
// rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
rows, err := queryBackwardExtremity(s, ctx, s.selectBackwardExtremitiesForRoomStmt, params)
var rows []backwardExtremityCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectBackwardExtremitiesForRoomStmt, params, &rows)
if err != nil {
return
@ -223,14 +211,18 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
// _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": knownEventID,
}
var rows []backwardExtremityCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteBackwardExtremityStmt, params, &rows)
rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremityStmt, params)
if err != nil {
return
}
@ -249,13 +241,18 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
// _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremitiesForRoomStmt, params)
var rows []backwardExtremityCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteBackwardExtremitiesForRoomStmt, params, &rows)
if err != nil {
return
}

View file

@ -52,7 +52,7 @@ import (
// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id);
// `
type CurrentRoomStateCosmos struct {
type currentRoomStateCosmos struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
Type string `json:"type"`
@ -64,9 +64,9 @@ type CurrentRoomStateCosmos struct {
AddedAt int64 `json:"added_at"`
}
type CurrentRoomStateCosmosData struct {
type currentRoomStateCosmosData struct {
cosmosdbapi.CosmosDocument
CurrentRoomState CurrentRoomStateCosmos `json:"mx_syncapi_current_room_state"`
CurrentRoomState currentRoomStateCosmos `json:"mx_syncapi_current_room_state"`
}
// const upsertRoomStateSQL = "" +
@ -132,50 +132,16 @@ type currentRoomStateStatements struct {
jsonPropertyName string
}
func queryCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []CurrentRoomStateCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *currentRoomStateStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryCurrentRoomStateDistinct(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []CurrentRoomStateCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *currentRoomStateStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*CurrentRoomStateCosmosData, error) {
response := CurrentRoomStateCosmosData{}
func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*currentRoomStateCosmosData, error) {
response := currentRoomStateCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -191,7 +157,7 @@ func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, doc
return &response, err
}
func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData CurrentRoomStateCosmosData) error {
func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData currentRoomStateCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -228,12 +194,15 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
// rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
rows, err := queryCurrentRoomState(s, ctx, s.selectJoinedUsersStmt, params)
var rows []currentRoomStateCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectJoinedUsersStmt, params, &rows)
if err != nil {
return nil, err
@ -264,14 +233,18 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
// stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
// rows, err := stmt.QueryContext(ctx, userID, membership)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": membership,
}
rows, err := queryCurrentRoomStateDistinct(s, ctx, s.selectRoomIDsWithMembershipStmt, params)
var rows []currentRoomStateCosmos
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectRoomIDsWithMembershipStmt, params, &rows)
if err != nil {
return nil, err
@ -296,9 +269,8 @@ func (s *currentRoomStateStatements) SelectCurrentState(
// "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
// // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": stateFilter.Limit,
}
@ -309,7 +281,12 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter.Types, stateFilter.NotTypes,
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
)
rows, err := queryCurrentRoomState(s, ctx, stmt, params)
var rows []currentRoomStateCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), stmt, params, &rows)
if err != nil {
return nil, err
@ -325,13 +302,16 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
// stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventID,
}
rows, err := queryCurrentRoomState(s, ctx, s.deleteRoomStateByEventIDStmt, params)
var rows []currentRoomStateCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteRoomStateByEventIDStmt, params, &rows)
for _, item := range rows {
err = deleteCurrentRoomState(s, ctx, item)
@ -348,13 +328,16 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom(
// "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
// stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params)
var rows []currentRoomStateCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.DeleteRoomStateForRoomStmt, params, &rows)
for _, item := range rows {
err = deleteCurrentRoomState(s, ctx, item)
@ -407,18 +390,16 @@ func (s *currentRoomStateStatements) UpsertRoomState(
// addedAt,
// )
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// " ON CONFLICT (room_id, type, state_key)" +
docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), event.Type(), *event.StateKey())
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
membershipData := ""
if membership != nil {
membershipData = *membership
}
dbData, _ := getEvent(s, ctx, pk, cosmosDocId)
dbData, _ := getEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
// " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
dbData.SetUpdateTime()
@ -429,7 +410,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
dbData.CurrentRoomState.Membership = membershipData
dbData.CurrentRoomState.AddedAt = int64(addedAt)
} else {
data := CurrentRoomStateCosmos{
data := currentRoomStateCosmos{
RoomID: event.RoomID(),
EventID: event.EventID(),
Type: event.Type(),
@ -441,8 +422,8 @@ func (s *currentRoomStateStatements) UpsertRoomState(
AddedAt: int64(addedAt),
}
dbData = &CurrentRoomStateCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &currentRoomStateCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
CurrentRoomState: data,
}
}
@ -480,13 +461,17 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
// query := strings.Replace(selectEventsWithEventIDsSQL, "@x2", sql.QueryVariadic(n), 1)
// rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventIDs,
}
rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params)
var rows []currentRoomStateCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.DeleteRoomStateForRoomStmt, params, &rows)
if err != nil {
return nil, err
@ -502,7 +487,7 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
}
// Copied from output_room_events_table
func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData) ([]types.StreamEvent, error) {
func rowsToStreamEventsFromCurrentRoomState(rows *[]currentRoomStateCosmosData) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for _, item := range *rows {
var (
@ -546,7 +531,7 @@ func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData)
return result, nil
}
func rowsToEvents(rows *[]CurrentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) {
func rowsToEvents(rows *[]currentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) {
result := []*gomatrixserverlib.HeaderedEvent{}
for _, item := range *rows {
var eventID string
@ -570,12 +555,10 @@ func (s *currentRoomStateStatements) SelectStateEvent(
// stmt := s.selectStateEventStmt
var res []byte
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// " ON CONFLICT (room_id, type, state_key)" +
docId := fmt.Sprintf("%s_%s_%s", roomID, evType, stateKey)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
var response, err = getEvent(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, err = getEvent(s, ctx, s.getPartitionKey(), cosmosDocId)
// err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == cosmosdbutil.ErrNoRows {

View file

@ -41,15 +41,15 @@ import (
// CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
// `
type FilterCosmos struct {
type filterCosmos struct {
ID int64 `json:"id"`
Filter []byte `json:"filter"`
Localpart string `json:"localpart"`
}
type FilterCosmosData struct {
type filterCosmosData struct {
cosmosdbapi.CosmosDocument
Filter FilterCosmos `json:"mx_syncapi_filter"`
Filter filterCosmos `json:"mx_syncapi_filter"`
}
// const selectFilterSQL = "" +
@ -72,34 +72,16 @@ type filterStatements struct {
tableName string
}
func queryFilter(s *filterStatements, ctx context.Context, qry string, params map[string]interface{}) ([]FilterCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []FilterCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
func (s *filterStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*FilterCosmosData, error) {
response := FilterCosmosData{}
func (s *filterStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*filterCosmosData, error) {
response := filterCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -134,12 +116,10 @@ func (s *filterStatements) SelectFilter(
var filterData []byte
// err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (id, localpart)
docId := fmt.Sprintf("%s_%s", localpart, filterID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response, err = getFilter(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, err = getFilter(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return nil, err
@ -187,21 +167,25 @@ func (s *filterStatements) InsertFilter(
// TODO: See if we can avoid the search by Content []byte
// "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
"@x3": filterJSON,
}
response, err := queryFilter(s, ctx, s.selectFilterIDByContentStmt, params)
var rows []filterCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectFilterIDByContentStmt, params, &rows)
if err != nil && err != cosmosdbutil.ErrNoRows {
return "", err
}
if response != nil {
existingFilterID = fmt.Sprintf("%d", response[0].Filter.ID)
if len(rows) > 0 {
existingFilterID = fmt.Sprintf("%d", rows[0].Filter.ID)
}
// If it does, return the existing ID
if existingFilterID != "" {
@ -217,7 +201,7 @@ func (s *filterStatements) InsertFilter(
return "", seqErr
}
data := FilterCosmos{
data := filterCosmos{
ID: seqID,
Localpart: localpart,
Filter: filterJSON,
@ -225,11 +209,10 @@ func (s *filterStatements) InsertFilter(
// UNIQUE (id, localpart)
docId := fmt.Sprintf("%s_%d", localpart, seqID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = FilterCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = filterCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Filter: data,
}

View file

@ -43,7 +43,7 @@ import (
// CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
// `
type InviteEventCosmos struct {
type inviteEventCosmos struct {
ID int64 `json:"id"`
EventID string `json:"event_id"`
RoomID string `json:"room_id"`
@ -52,13 +52,13 @@ type InviteEventCosmos struct {
Deleted bool `json:"deleted"`
}
type InviteEventCosmosMaxNumber struct {
type inviteEventCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type InviteEventCosmosData struct {
type inviteEventCosmosData struct {
cosmosdbapi.CosmosDocument
InviteEvent InviteEventCosmos `json:"mx_syncapi_invite_event"`
InviteEvent inviteEventCosmos `json:"mx_syncapi_invite_event"`
}
// const insertInviteEventSQL = "" +
@ -95,51 +95,16 @@ type inviteEventsStatements struct {
tableName string
}
func queryInviteEvent(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []InviteEventCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *inviteEventsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryInviteEventMaxNumber(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []InviteEventCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
return response, nil
func (s *inviteEventsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*InviteEventCosmosData, error) {
response := InviteEventCosmosData{}
func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*inviteEventCosmosData, error) {
response := inviteEventCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -155,7 +120,7 @@ func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, d
return &response, err
}
func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite InviteEventCosmosData) (*InviteEventCosmosData, error) {
func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite inviteEventCosmosData) (*inviteEventCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -207,7 +172,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
// *inviteEvent.StateKey(),
// headeredJSON,
// )
data := InviteEventCosmos{
data := inviteEventCosmos{
ID: int64(streamPos),
RoomID: inviteEvent.RoomID(),
EventID: inviteEvent.EventID(),
@ -215,14 +180,12 @@ func (s *inviteEventsStatements) InsertInviteEvent(
HeaderedEventJSON: headeredJSON,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// id INTEGER PRIMARY KEY,
docId := fmt.Sprintf("%d", streamPos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = InviteEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = inviteEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
InviteEvent: data,
}
@ -249,14 +212,18 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt)
// _, err = stmt.ExecContext(ctx, streamPos, inviteEventID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": inviteEventID,
}
response, err := queryInviteEvent(s, ctx, s.deleteInviteEventStmt, params)
var rows []inviteEventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteInviteEventStmt, params, &rows)
for _, item := range response {
for _, item := range rows {
item.InviteEvent.Deleted = true
item.InviteEvent.ID = int64(streamPos)
setInviteEvent(s, ctx, item)
@ -276,14 +243,18 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
// stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
// rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": targetUserID,
"@x3": r.Low(),
"@x4": r.High(),
}
rows, err := queryInviteEvent(s, ctx, s.selectInviteEventsInRangeStmt, params)
var rows []inviteEventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectInviteEventsInRangeStmt, params, &rows)
if err != nil {
return nil, nil, err
@ -333,14 +304,18 @@ func (s *inviteEventsStatements) SelectMaxInviteID(
// stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
response, err := queryInviteEventMaxNumber(s, ctx, s.selectMaxInviteIDStmt, params)
var rows []inviteEventCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxInviteIDStmt, params, &rows)
if response != nil {
nullableID.Int64 = response[0].Max
if len(rows) > 0 {
nullableID.Int64 = rows[0].Max
}
if nullableID.Valid {

View file

@ -51,7 +51,7 @@ import (
// );
// `
type MembershipCosmos struct {
type membershipCosmos struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
Membership string `json:"membership"`
@ -60,9 +60,9 @@ type MembershipCosmos struct {
TopologicalPos int64 `json:"topological_pos"`
}
type MembershipCosmosData struct {
type membershipCosmosData struct {
cosmosdbapi.CosmosDocument
Membership MembershipCosmos `json:"mx_syncapi_membership"`
Membership membershipCosmos `json:"mx_syncapi_membership"`
}
// const upsertMembershipSQL = "" +
@ -88,8 +88,16 @@ type membershipsStatements struct {
tableName string
}
func getMembership(s *membershipsStatements, ctx context.Context, pk string, docId string) (*MembershipCosmosData, error) {
response := MembershipCosmosData{}
func (s *membershipsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *membershipsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getMembership(s *membershipsStatements, ctx context.Context, pk string, docId string) (*membershipCosmosData, error) {
response := membershipCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -105,27 +113,6 @@ func getMembership(s *membershipsStatements, ctx context.Context, pk string, doc
return &response, err
}
func queryMembership(s *membershipsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []MembershipCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func NewCosmosDBMembershipsTable(db *SyncServerDatasource) (tables.Memberships, error) {
s := &membershipsStatements{
db: db,
@ -158,13 +145,11 @@ func (s *membershipsStatements) UpsertMembership(
// topologicalPos,
// )
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// UNIQUE (room_id, user_id, membership)
docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), *event.StateKey(), membership)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getMembership(s, ctx, pk, cosmosDocId)
dbData, _ := getMembership(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
// " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6"
dbData.SetUpdateTime()
@ -172,7 +157,7 @@ func (s *membershipsStatements) UpsertMembership(
dbData.Membership.StreamPos = int64(streamPos)
dbData.Membership.TopologicalPos = int64(topologicalPos)
} else {
data := MembershipCosmos{
data := membershipCosmos{
RoomID: event.RoomID(),
UserID: *event.StateKey(),
Membership: membership,
@ -181,8 +166,8 @@ func (s *membershipsStatements) UpsertMembership(
TopologicalPos: int64(topologicalPos),
}
dbData = &MembershipCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &membershipCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Membership: data,
}
}
@ -209,15 +194,19 @@ func (s *membershipsStatements) SelectMembership(
// " LIMIT 1"
// err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": userID,
"@x4": memberships,
}
// orig := strings.Replace(selectMembershipSQL, "@x4", cosmosdbutil.QueryVariadicOffset(len(memberships), 2), 1)
rows, err := queryMembership(s, ctx, selectMembershipSQL, params)
var rows []membershipCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), selectMembershipSQL, params, &rows)
if err != nil || len(rows) == 0 {
return "", 0, 0, err

View file

@ -22,8 +22,6 @@ import (
"fmt"
"sort"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/roomserver/api"
@ -52,7 +50,7 @@ import (
// );
// `
type OutputRoomEventCosmos struct {
type outputRoomEventCosmos struct {
ID int64 `json:"id"`
EventID string `json:"event_id"`
RoomID string `json:"room_id"`
@ -67,13 +65,13 @@ type OutputRoomEventCosmos struct {
ExcludeFromSync bool `json:"exclude_from_sync"`
}
type OutputRoomEventCosmosMaxNumber struct {
type outputRoomEventCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type OutputRoomEventCosmosData struct {
type outputRoomEventCosmosData struct {
cosmosdbapi.CosmosDocument
OutputRoomEvent OutputRoomEventCosmos `json:"mx_syncapi_output_room_event"`
OutputRoomEvent outputRoomEventCosmos `json:"mx_syncapi_output_room_event"`
}
// const insertEventSQL = "" +
@ -152,49 +150,15 @@ type outputRoomEventsStatements struct {
jsonPropertyName string
}
func queryOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []OutputRoomEventCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *outputRoomEventsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryOutputRoomEventNumber(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []OutputRoomEventCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
func (s *outputRoomEventsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent OutputRoomEventCosmosData) (*OutputRoomEventCosmosData, error) {
func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent outputRoomEventCosmosData) (*outputRoomEventCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outputRoomEvent.Pk, outputRoomEvent.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -206,7 +170,7 @@ func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outp
return &outputRoomEvent, ex
}
func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData OutputRoomEventCosmosData) error {
func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData outputRoomEventCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -243,14 +207,19 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": event.EventID(),
}
// _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params)
var rows []outputRoomEventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteEventsForRoomStmt, params, &rows)
if err != nil {
return err
}
@ -261,7 +230,6 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
}
return err
return err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@ -277,9 +245,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
// " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": r.Low(),
"@x3": r.High(),
"@x4": stateFilter.Limit,
@ -292,7 +259,13 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, query, params)
var rows []outputRoomEventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), query, params, &rows)
if err != nil {
return nil, nil, err
}
@ -374,13 +347,17 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
) (id int64, err error) {
var nullableID sql.NullInt64
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
// stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt)
var rows []outputRoomEventCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxEventIDStmt, params, &rows)
rows, err := queryOutputRoomEventNumber(s, ctx, s.selectMaxEventIDStmt, params)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil {
@ -464,7 +441,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// excludeFromSync,
// )
data := OutputRoomEventCosmos{
data := outputRoomEventCosmos{
ID: int64(streamPos),
RoomID: event.RoomID(),
EventID: event.EventID(),
@ -482,14 +459,12 @@ func (s *outputRoomEventsStatements) InsertEvent(
data.TransactionID = *txnID
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// id INTEGER PRIMARY KEY,
docId := fmt.Sprintf("%d", streamPos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = OutputRoomEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = outputRoomEventCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
OutputRoomEvent: data,
}
@ -521,9 +496,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
query = selectRecentEventsSQL
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": r.Low(),
"@x4": r.High(),
@ -538,7 +512,12 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, query, params)
var rows []outputRoomEventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), query, params, &rows)
if err != nil {
return nil, false, err
@ -577,9 +556,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// " WHERE room_id = $1 AND id > $2 AND id <= $3"
// // WHEN, ORDER BY (and not LIMIT) are appended by prepareWithFilters
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": r.Low(),
"@x4": r.High(),
@ -593,7 +571,13 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
)
// rows, err := stmt.QueryContext(ctx, params...)
rows, err := queryOutputRoomEvent(s, ctx, stmt, params)
var rows []outputRoomEventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), stmt, params, &rows)
if err != nil {
return nil, err
}
@ -622,14 +606,19 @@ func (s *outputRoomEventsStatements) SelectEvents(
// stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
for _, eventID := range eventIDs {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventID,
}
// rows, err := stmt.QueryContext(ctx, eventID)
rows, err := queryOutputRoomEvent(s, ctx, s.selectEventsStmt, params)
var rows []outputRoomEventCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectEventsStmt, params, &rows)
if err != nil {
return nil, err
}
@ -645,14 +634,19 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom(
) (err error) {
// "DELETE FROM syncapi_output_room_events WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params)
var rows []outputRoomEventCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteEventsForRoomStmt, params, &rows)
if err != nil {
return err
}
@ -664,7 +658,7 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom(
return err
}
func rowsToStreamEvents(rows *[]OutputRoomEventCosmosData) ([]types.StreamEvent, error) {
func rowsToStreamEvents(rows *[]outputRoomEventCosmosData) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for _, item := range *rows {
var (

View file

@ -39,16 +39,16 @@ import (
// -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
// `
type OutputRoomEventTopologyCosmos struct {
type outputRoomEventTopologyCosmos struct {
EventID string `json:"event_id"`
TopologicalPosition int64 `json:"topological_position"`
StreamPosition int64 `json:"stream_position"`
RoomID string `json:"room_id"`
}
type OutputRoomEventTopologyCosmosData struct {
type outputRoomEventTopologyCosmosData struct {
cosmosdbapi.CosmosDocument
OutputRoomEventTopology OutputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"`
OutputRoomEventTopology outputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"`
}
// const insertEventInTopologySQL = "" +
@ -126,29 +126,16 @@ type outputRoomEventsTopologyStatements struct {
tableName string
}
func queryOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventTopologyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []OutputRoomEventTopologyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *outputRoomEventsTopologyStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, pk string, docId string) (*OutputRoomEventTopologyCosmosData, error) {
response := OutputRoomEventTopologyCosmosData{}
func (s *outputRoomEventsTopologyStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, pk string, docId string) (*outputRoomEventTopologyCosmosData, error) {
response := outputRoomEventTopologyCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -164,7 +151,7 @@ func getOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx conte
return &response, err
}
func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData OutputRoomEventTopologyCosmosData) error {
func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData outputRoomEventTopologyCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
@ -203,26 +190,24 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
// " VALUES ($1, $2, $3, $4)" +
// " ON CONFLICT DO NOTHING"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE(topological_position, room_id, stream_position)
docId := fmt.Sprintf("%d_%s_%d", event.Depth(), event.RoomID(), pos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var err error
dbData, _ := getOutputRoomEventTopology(s, ctx, pk, cosmosDocId)
dbData, _ := getOutputRoomEventTopology(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
// " ON CONFLICT DO NOTHING"
} else {
data := OutputRoomEventTopologyCosmos{
data := outputRoomEventTopologyCosmos{
EventID: event.EventID(),
TopologicalPosition: event.Depth(),
RoomID: event.RoomID(),
StreamPosition: int64(pos),
}
dbData = &OutputRoomEventTopologyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &outputRoomEventTopologyCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
OutputRoomEventTopology: data,
}
// _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext(
@ -265,9 +250,8 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
}
// Query the event IDs.
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": minDepth,
"@x4": maxDepth,
@ -275,8 +259,12 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
"@x6": maxStreamPos,
"@x7": limit,
}
rows, err := queryOutputRoomEventTopology(s, ctx, stmt, params)
var rows []outputRoomEventTopologyCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), stmt, params, &rows)
// rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
if err == sql.ErrNoRows {
@ -308,13 +296,17 @@ func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
// " WHERE event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": eventID,
}
var rows []outputRoomEventTopologyCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPositionInTopologyStmt, params, &rows)
rows, err := queryOutputRoomEventTopology(s, ctx, s.selectPositionInTopologyStmt, params)
// stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt)
if err != nil {
@ -342,13 +334,17 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
// ") ORDER BY stream_position DESC LIMIT 1"
// stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
var rows []outputRoomEventTopologyCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectMaxPositionInTopologyStmt, params, &rows)
rows, err := queryOutputRoomEventTopology(s, ctx, s.selectMaxPositionInTopologyStmt, params)
// err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
if err != nil {
@ -369,13 +365,17 @@ func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
) (err error) {
// "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
}
var rows []outputRoomEventTopologyCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteTopologyForRoomStmt, params, &rows)
rows, err := queryOutputRoomEventTopology(s, ctx, s.deleteTopologyForRoomStmt, params)
// _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
if err != nil {

View file

@ -42,7 +42,7 @@ import (
// CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id);
// `
type PeekCosmos struct {
type peekCosmos struct {
ID int64 `json:"id"`
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
@ -52,13 +52,13 @@ type PeekCosmos struct {
// creation_ts int64 `json:"creation_ts"`
}
type PeekCosmosMaxNumber struct {
type peekCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type PeekCosmosData struct {
type peekCosmosData struct {
cosmosdbapi.CosmosDocument
Peek PeekCosmos `json:"mx_syncapi_peek"`
Peek peekCosmos `json:"mx_syncapi_peek"`
}
// const insertPeekSQL = "" +
@ -115,50 +115,16 @@ type peekStatements struct {
tableName string
}
func queryPeek(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []PeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *peekStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryPeekMaxNumber(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []PeekCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
return response, nil
func (s *peekStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (*PeekCosmosData, error) {
response := PeekCosmosData{}
func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (*peekCosmosData, error) {
response := peekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -174,7 +140,7 @@ func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (*
return &response, err
}
func setPeek(s *peekStatements, ctx context.Context, peek PeekCosmosData) (*PeekCosmosData, error) {
func setPeek(s *peekStatements, ctx context.Context, peek peekCosmosData) (*peekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(peek.Pk, peek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -213,28 +179,26 @@ func (s *peekStatements) InsertPeek(
// " (id, room_id, user_id, device_id, creation_ts, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE(room_id, user_id, device_id)
docId := fmt.Sprintf("%d_%s_%d", roomID, userID, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
docId := fmt.Sprintf("%s_%s_%s", roomID, userID, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getPeek(s, ctx, pk, cosmosDocId)
dbData, _ := getPeek(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
// " (id, room_id, user_id, device_id, creation_ts, deleted)" +
// " VALUES ($1, $2, $3, $4, $5, false)"
dbData.SetUpdateTime()
dbData.Peek.Deleted = false
} else {
data := PeekCosmos{
data := peekCosmos{
ID: int64(streamPos),
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
}
dbData = &PeekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &peekCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Peek: data,
}
}
@ -257,15 +221,20 @@ func (s *peekStatements) DeletePeek(
// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": userID,
"@x4": deviceID,
}
rows, err := queryPeek(s, ctx, s.deletePeekStmt, params)
var rows []peekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deletePeekStmt, params, &rows)
// _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID)
numAffected := len(rows)
@ -295,14 +264,19 @@ func (s *peekStatements) DeletePeeks(
) (types.StreamPosition, error) {
// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": roomID,
"@x3": userID,
}
rows, err := queryPeek(s, ctx, s.deletePeekStmt, params)
var rows []peekCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deletePeekStmt, params, &rows)
// result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
if err != nil {
return 0, err
@ -334,16 +308,20 @@ func (s *peekStatements) SelectPeeksInRange(
) (peeks []types.Peek, err error) {
// "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
"@x4": r.Low(),
"@x5": r.High(),
}
var rows []peekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPeeksInRangeStmt, params, &rows)
rows, err := queryPeek(s, ctx, s.selectPeeksInRangeStmt, params)
// rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
if err != nil {
return
@ -371,12 +349,17 @@ func (s *peekStatements) SelectPeekingDevices(
// "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
rows, err := queryPeek(s, ctx, s.selectPeekingDevicesStmt, params)
var rows []peekCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPeekingDevicesStmt, params, &rows)
// rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
if err != nil {
return nil, err
@ -405,12 +388,16 @@ func (s *peekStatements) SelectMaxPeekID(
// stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
var nullableID sql.NullInt64
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
var rows []peekCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxPeekIDStmt, params, &rows)
rows, err := queryPeekMaxNumber(s, ctx, s.selectMaxPeekIDStmt, params)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if rows != nil {

View file

@ -41,7 +41,7 @@ import (
// CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id);
// `
type ReceiptCosmos struct {
type receiptCosmos struct {
ID int64 `json:"id"`
RoomID string `json:"room_id"`
ReceiptType string `json:"receipt_type"`
@ -50,13 +50,13 @@ type ReceiptCosmos struct {
ReceiptTS int64 `json:"receipt_ts"`
}
type ReceiptCosmosMaxNumber struct {
type receiptCosmosMaxNumber struct {
Max int64 `json:"number"`
}
type ReceiptCosmosData struct {
type receiptCosmosData struct {
cosmosdbapi.CosmosDocument
Receipt ReceiptCosmos `json:"mx_syncapi_receipt"`
Receipt receiptCosmos `json:"mx_syncapi_receipt"`
}
// const upsertReceipt = "" +
@ -87,46 +87,12 @@ type receiptStatements struct {
tableName string
}
func queryReceipt(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ReceiptCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *receiptStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryReceiptNumber(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ReceiptCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
return response, nil
func (s *receiptStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func NewCosmosDBReceiptsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Receipts, error) {
@ -152,7 +118,11 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
// " ON CONFLICT (room_id, receipt_type, user_id)" +
// " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
data := ReceiptCosmos{
// CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id)
docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId)
cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, r.getCollectionName(), docId)
data := receiptCosmos{
ID: int64(pos),
RoomID: roomId,
ReceiptType: receiptType,
@ -161,14 +131,8 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
ReceiptTS: int64(timestamp),
}
var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName)
var pk = cosmosdbapi.GetPartitionKey(r.db.cosmosConfig.ContainerName, dbCollectionName)
// CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id)
docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId)
cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, dbCollectionName, docId)
var dbData = ReceiptCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, r.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = receiptCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(r.getCollectionName(), r.db.cosmosConfig.TenantName, r.getPartitionKey(), cosmosDocId),
Receipt: data,
}
@ -197,14 +161,18 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
// for k, v := range roomIDs {
// params[k+1] = v
var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": r.getCollectionName(),
"@x2": streamPos,
"@x3": roomIDs,
}
var rows []receiptCosmosData
err := cosmosdbapi.PerformQuery(ctx,
r.db.connection,
r.db.cosmosConfig.DatabaseName,
r.db.cosmosConfig.ContainerName,
r.getPartitionKey(), selectRoomReceipts, params, &rows)
rows, err := queryReceipt(r, ctx, selectRoomReceipts, params)
// rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil {
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
@ -239,12 +207,16 @@ func (s *receiptStatements) SelectMaxReceiptID(
// "SELECT MAX(id) FROM syncapi_receipts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
var rows []receiptCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxReceiptID, params, &rows)
rows, err := queryReceiptNumber(s, ctx, s.selectMaxReceiptID, params)
// stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID)
if rows != nil {

View file

@ -94,46 +94,12 @@ type sendToDeviceStatements struct {
tableName string
}
func querySendToDevice(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []SendToDeviceCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *sendToDeviceStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosMaxNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []SendToDeviceCosmosMaxNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, nil
}
return response, nil
func (s *sendToDeviceStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func deleteSendToDevice(s *sendToDeviceStatements, ctx context.Context, dbData SendToDeviceCosmosData) error {
@ -180,6 +146,10 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
// INSERT INTO syncapi_send_to_device (user_id, device_id, content)
// VALUES ($1, $2, $3)
// NO CONSTRAINT
docId := fmt.Sprintf("%d", pos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := SendToDeviceCosmos{
ID: int64(pos),
UserID: userID,
@ -187,14 +157,8 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
Content: content,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
// NO CONSTRAINT
docId := fmt.Sprintf("%d", pos)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
var dbData = SendToDeviceCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
SendToDevice: data,
}
@ -217,16 +181,21 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
// WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
// ORDER BY id DESC
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
"@x4": from,
"@x5": to,
}
rows, err := querySendToDevice(s, ctx, s.selectSendToDeviceMessagesStmt, params)
var rows []SendToDeviceCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectSendToDeviceMessagesStmt, params, &rows)
if err != nil {
return
}
@ -268,16 +237,21 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
// DELETE FROM syncapi_send_to_device
// WHERE user_id = $1 AND device_id = $2 AND id < $3
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": deviceID,
"@x4": pos,
}
// _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
rows, err := querySendToDevice(s, ctx, s.deleteSendToDeviceMessagesStmt, params)
var rows []SendToDeviceCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.deleteSendToDeviceMessagesStmt, params, &rows)
if err != nil {
return err
}
@ -297,12 +271,16 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
var nullableID sql.NullInt64
// "SELECT MAX(id) FROM syncapi_send_to_device"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
var rows []SendToDeviceCosmosMaxNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectMaxSendToDeviceIDStmt, params, &rows)
rows, err := querySendToDeviceNumber(s, ctx, s.selectMaxSendToDeviceIDStmt, params)
// stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
// err = stmt.QueryRowContext(ctx).Scan(&nullableID)

View file

@ -38,18 +38,18 @@ import (
// );
// `
type AccountDataCosmosData struct {
cosmosdbapi.CosmosDocument
AccountData AccountDataCosmos `json:"mx_userapi_accountdata"`
}
type AccountDataCosmos struct {
type accountDataCosmos struct {
LocalPart string `json:"local_part"`
RoomId string `json:"room_id"`
Type string `json:"type"`
Content []byte `json:"content"`
}
type accountDataCosmosData struct {
cosmosdbapi.CosmosDocument
AccountData accountDataCosmos `json:"mx_userapi_accountdata"`
}
type accountDataStatements struct {
db *Database
// insertAccountDataStmt *sql.Stmt
@ -58,6 +58,14 @@ type accountDataStatements struct {
tableName string
}
func (s *accountDataStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *accountDataStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func (s *accountDataStatements) prepare(db *Database) (err error) {
s.db = db
s.selectAccountDataStmt = "select * from c where c._cn = @x1 and c.mx_userapi_accountdata.local_part = @x2"
@ -66,8 +74,8 @@ func (s *accountDataStatements) prepare(db *Database) (err error) {
return
}
func getAccountData(s *accountDataStatements, ctx context.Context, pk string, docId string) (*AccountDataCosmosData, error) {
response := AccountDataCosmosData{}
func getAccountData(s *accountDataStatements, ctx context.Context, pk string, docId string) (*accountDataCosmosData, error) {
response := accountDataCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -83,34 +91,12 @@ func getAccountData(s *accountDataStatements, ctx context.Context, pk string, do
return &response, err
}
func queryAccountData(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []AccountDataCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
// INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
id := ""
if roomID == "" {
id = fmt.Sprintf("%s_%s", localpart, dataType)
@ -119,24 +105,23 @@ func (s *accountDataStatements) insertAccountData(
}
docId := id
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
dbData, _ := getAccountData(s, ctx, pk, cosmosDocId)
dbData, _ := getAccountData(s, ctx, s.getPartitionKey(), cosmosDocId)
if dbData != nil {
// ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
dbData.SetUpdateTime()
dbData.AccountData.Content = content
} else {
var result = AccountDataCosmos{
var result = accountDataCosmos{
LocalPart: localpart,
RoomId: roomID,
Type: dataType,
Content: content,
}
dbData = &AccountDataCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData = &accountDataCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
AccountData: result,
}
}
@ -157,13 +142,16 @@ func (s *accountDataStatements) selectAccountData(
error,
) {
// "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
}
response, err := queryAccountData(s, ctx, s.selectAccountDataStmt, params)
var rows []accountDataCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountDataStmt, params, &rows)
if err != nil {
return nil, nil, err
@ -172,8 +160,8 @@ func (s *accountDataStatements) selectAccountData(
global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{}
for i := 0; i < len(response); i++ {
var row = response[i]
for i := 0; i < len(rows); i++ {
var row = rows[i]
var roomID = row.AccountData.RoomId
if roomID != "" {
if _, ok := rooms[row.AccountData.RoomId]; !ok {
@ -194,25 +182,28 @@ func (s *accountDataStatements) selectAccountDataByType(
var bytes []byte
// "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accountDatas.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
"@x3": roomID,
"@x4": dataType,
}
response, err := queryAccountData(s, ctx, s.selectAccountDataByTypeStmt, params)
var rows []accountDataCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountDataByTypeStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
if len(rows) == 0 {
return data, nil
}
bytes = response[0].AccountData.Content
bytes = rows[0].AccountData.Content
data = json.RawMessage(bytes)
return

View file

@ -46,7 +46,7 @@ import (
// );
// `
type AccountCosmos struct {
type accountCosmos struct {
UserID string `json:"user_id"`
Localpart string `json:"local_part"`
ServerName gomatrixserverlib.ServerName `json:"server_name"`
@ -56,12 +56,12 @@ type AccountCosmos struct {
Created int64 `json:"created_ts"`
}
type AccountCosmosData struct {
type accountCosmosData struct {
cosmosdbapi.CosmosDocument
Account AccountCosmos `json:"mx_userapi_account"`
Account accountCosmos `json:"mx_userapi_account"`
}
type AccountCosmosUserCount struct {
type accountCosmosUserCount struct {
UserCount int64 `json:"usercount"`
}
@ -74,6 +74,14 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *accountsStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.selectPasswordHashStmt = "select * from c where c._cn = @x1 and c.mx_userapi_account.local_part = @x2 and c.mx_userapi_account.is_deactivated = false"
@ -84,29 +92,8 @@ func (s *accountsStatements) prepare(db *Database, server gomatrixserverlib.Serv
return
}
func queryAccount(s *accountsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []AccountCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*AccountCosmosData, error) {
response := AccountCosmosData{}
func getAccount(s *accountsStatements, ctx context.Context, pk string, docId string) (*accountCosmosData, error) {
response := accountCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -122,8 +109,8 @@ func getAccount(s *accountsStatements, ctx context.Context, pk string, docId str
return &response, err
}
func setAccount(s *accountsStatements, ctx context.Context, account AccountCosmosData) (*AccountCosmosData, error) {
response := AccountCosmosData{}
func setAccount(s *accountsStatements, ctx context.Context, account accountCosmosData) (*accountCosmosData, error) {
response := accountCosmosData{}
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(account.Pk, account.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -135,7 +122,7 @@ func setAccount(s *accountsStatements, ctx context.Context, account AccountCosmo
return &response, ex
}
func mapFromAccount(db AccountCosmos) api.Account {
func mapFromAccount(db accountCosmos) api.Account {
return api.Account{
AppServiceID: db.AppServiceID,
Localpart: db.Localpart,
@ -144,8 +131,8 @@ func mapFromAccount(db AccountCosmos) api.Account {
}
}
func mapToAccount(api api.Account) AccountCosmos {
return AccountCosmos{
func mapToAccount(api api.Account) accountCosmos {
return accountCosmos{
AppServiceID: api.AppServiceID,
Localpart: api.Localpart,
ServerName: api.ServerName,
@ -175,14 +162,11 @@ func (s *accountsStatements) insertAccount(
data.PasswordHash = hash
data.IsDeactivated = false
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := result.Localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = AccountCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = accountCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Account: data,
}
@ -206,12 +190,10 @@ func (s *accountsStatements) updatePassword(
) (err error) {
// "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}
@ -230,13 +212,10 @@ func (s *accountsStatements) deactivateAccount(
) (err error) {
// "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getAccount(s, ctx, pk, cosmosDocId)
var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}
@ -255,27 +234,30 @@ func (s *accountsStatements) selectPasswordHash(
) (hash string, err error) {
// "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
}
response, err := queryAccount(s, ctx, s.selectPasswordHashStmt, params)
var rows []accountCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectPasswordHashStmt, params, &rows)
if err != nil {
return "", err
}
if len(response) == 0 {
if len(rows) == 0 {
return "", errors.New(fmt.Sprintf("Localpart %s not found", localpart))
}
if len(response) != 1 {
if len(rows) != 1 {
return "", errors.New(fmt.Sprintf("Localpart %s has multiple entries", localpart))
}
return response[0].Account.PasswordHash, nil
return rows[0].Account.PasswordHash, nil
}
func (s *accountsStatements) selectAccountByLocalpart(
@ -284,23 +266,26 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account
// "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
}
response, err := queryAccount(s, ctx, s.selectAccountByLocalpartStmt, params)
var rows []accountCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectAccountByLocalpartStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
if len(rows) == 0 {
return nil, nil
}
acc = mapFromAccount(response[0].Account)
acc = mapFromAccount(rows[0].Account)
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
@ -312,25 +297,20 @@ func (s *accountsStatements) selectNewNumericLocalpart(
) (id int64, err error) {
// "SELECT COUNT(localpart) FROM account_accounts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []AccountCosmosUserCount
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
var options = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectNewNumericLocalpartStmt, params)
var _, ex = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []accountCosmosUserCount
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
options)
s.getPartitionKey(), s.selectNewNumericLocalpartStmt, params, &rows)
if ex != nil {
return -1, ex
if err != nil {
return -1, err
}
return int64(response[0].UserCount), nil
return int64(rows[0].UserCount), nil
}

View file

@ -40,12 +40,7 @@ import (
// CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
// `
type KeyBackupCosmosData struct {
cosmosdbapi.CosmosDocument
KeyBackup KeyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"`
}
type KeyBackupCosmos struct {
type keyBackupCosmos struct {
UserId string `json:"user_id"`
RoomId string `json:"room_id"`
SessionId string `json:"session_id"`
@ -56,7 +51,12 @@ type KeyBackupCosmos struct {
SessionData []byte `json:"session_data"`
}
type KeyBackupCosmosNumber struct {
type keyBackupCosmosData struct {
cosmosdbapi.CosmosDocument
KeyBackup keyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"`
}
type keyBackupCosmosNumber struct {
Number int64 `json:"number"`
}
@ -110,50 +110,16 @@ type keyBackupStatements struct {
serverName gomatrixserverlib.ServerName
}
func queryKeyBackup(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []KeyBackupCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *keyBackupStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func queryKeyBackupNumber(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []KeyBackupCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
func (s *keyBackupStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*KeyBackupCosmosData, error) {
response := KeyBackupCosmosData{}
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*keyBackupCosmosData, error) {
response := keyBackupCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -169,7 +135,7 @@ func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId
return &response, err
}
func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup KeyBackupCosmosData) (*KeyBackupCosmosData, error) {
func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup keyBackupCosmosData) (*keyBackupCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -200,13 +166,17 @@ func (s keyBackupStatements) countKeys(
// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
// err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": version,
}
rows, err := queryKeyBackupNumber(&s, ctx, s.countKeysStmt, params)
var rows []keyBackupCosmosNumber
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.countKeysStmt, params, &rows)
if err != nil {
return -1, err
@ -228,13 +198,11 @@ func (s *keyBackupStatements) insertBackupKey(
// _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
// ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData),
// )
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := KeyBackupCosmos{
data := keyBackupCosmos{
UserId: userID,
RoomId: key.RoomID,
SessionId: key.SessionID,
@ -245,8 +213,8 @@ func (s *keyBackupStatements) insertBackupKey(
SessionData: key.SessionData,
}
dbData := &KeyBackupCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &keyBackupCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
KeyBackup: data,
}
@ -270,13 +238,11 @@ func (s *keyBackupStatements) updateBackupKey(
// ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version,
// )
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackup(s, ctx, pk, cosmosDocId)
res, err := getKeyBackup(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return
@ -302,13 +268,17 @@ func (s *keyBackupStatements) selectKeys(
) (map[string]map[string]api.KeyBackupSession, error) {
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": version,
}
rows, err := queryKeyBackup(s, ctx, s.selectKeysStmt, params)
var rows []keyBackupCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysStmt, params, &rows)
if err != nil {
return nil, err
@ -327,14 +297,18 @@ func (s *keyBackupStatements) selectKeysByRoomID(
) (map[string]map[string]api.KeyBackupSession, error) {
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": version,
"@x4": roomID,
}
rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDStmt, params)
var rows []keyBackupCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysByRoomIDStmt, params, &rows)
if err != nil {
return nil, err
@ -355,15 +329,19 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
) (map[string]map[string]api.KeyBackupSession, error) {
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": userID,
"@x3": version,
"@x4": roomID,
"@x5": sessionID,
}
rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDAndSessionIDStmt, params)
var rows []keyBackupCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectKeysByRoomIDAndSessionIDStmt, params, &rows)
if err != nil {
return nil, err
@ -379,7 +357,7 @@ func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
return unpackKeys(ctx, &rows)
}
func unpackKeys(ctx context.Context, rows *[]KeyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) {
func unpackKeys(ctx context.Context, rows *[]keyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
for _, item := range *rows {
var key api.InternalKeyBackupSession

View file

@ -41,12 +41,7 @@ import (
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
// `
type KeyBackupVersionCosmosData struct {
cosmosdbapi.CosmosDocument
KeyBackupVersion KeyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"`
}
type KeyBackupVersionCosmos struct {
type keyBackupVersionCosmos struct {
UserId string `json:"user_id"`
Version int64 `json:"vesion"`
Algorithm string `json:"algorithm"`
@ -55,7 +50,12 @@ type KeyBackupVersionCosmos struct {
Deleted int `json:"deleted"`
}
type KeyBackupVersionCosmosNumber struct {
type keyBackupVersionCosmosData struct {
cosmosdbapi.CosmosDocument
KeyBackupVersion keyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"`
}
type keyBackupVersionCosmosNumber struct {
Number int64 `json:"number"`
}
@ -91,33 +91,16 @@ type keyBackupVersionStatements struct {
serverName gomatrixserverlib.ServerName
}
func queryKeyBackupVersionNumber(s *keyBackupVersionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupVersionCosmosNumber, error) {
var response []KeyBackupVersionCosmosNumber
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
var query = cosmosdbapi.GetQuery(qry, params)
var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
//WHen there is no data these GroupBy queries return errors
// if err != nil {
// return nil, err
// }
if len(response) == 0 {
return nil, cosmosdbutil.ErrNoRows
}
return response, nil
func (s *keyBackupVersionStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*KeyBackupVersionCosmosData, error) {
response := KeyBackupVersionCosmosData{}
func (s *keyBackupVersionStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*keyBackupVersionCosmosData, error) {
response := keyBackupVersionCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -133,7 +116,7 @@ func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk
return &response, err
}
func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup KeyBackupVersionCosmosData) (*KeyBackupVersionCosmosData, error) {
func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup keyBackupVersionCosmosData) (*keyBackupVersionCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -171,14 +154,11 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return "", seqErr
}
// err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
data := KeyBackupVersionCosmos{
data := keyBackupVersionCosmos{
UserId: userID,
Version: versionInt,
Algorithm: algorithm,
@ -187,8 +167,8 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
Deleted: 0,
}
dbData := &KeyBackupVersionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
dbData := &keyBackupVersionCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
KeyBackupVersion: data,
}
@ -211,13 +191,11 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
if err != nil {
return fmt.Errorf("invalid version")
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err
@ -243,13 +221,11 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
if err != nil {
return fmt.Errorf("invalid version")
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return err
@ -275,13 +251,11 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
if err != nil {
return false, fmt.Errorf("invalid version")
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return false, err
@ -309,17 +283,21 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
var versionInt int64
if version == "" {
// var v *int64 // allows nulls
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": s.db.cosmosConfig.TenantName,
"@x2": dbCollectionName,
"@x2": s.getCollectionName(),
"@x3": userID,
}
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
response, err1 := queryKeyBackupVersionNumber(s, ctx, s.selectLatestVersionStmt, params)
var rows []keyBackupVersionCosmosNumber
err = cosmosdbapi.PerformQueryAllPartitions(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.selectLatestVersionStmt, params, &rows)
if err1 != nil {
if err != nil {
if err == cosmosdbutil.ErrNoRows {
err = nil
}
@ -327,12 +305,12 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
// if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil {
// return
// }
if response == nil || len(response) == 0 {
if rows == nil || len(rows) == 0 {
err = cosmosdbutil.ErrNoRows
versionInt = 0
return
}
versionInt = response[0].Number
versionInt = rows[0].Number
} else {
if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil {
return
@ -342,13 +320,11 @@ func (s *keyBackupVersionStatements) selectKeyBackup(
if err != nil {
return
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
docId := fmt.Sprintf("%s_%d", userID, versionInt)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(), cosmosDocId)
if err != nil {
return

View file

@ -22,7 +22,7 @@ import (
// `
// OpenIDToken represents an OpenID token
type OpenIDTokenCosmos struct {
type openIDTokenCosmos struct {
Token string `json:"token"`
UserID string `json:"user_id"`
ExpiresAtMS int64 `json:"expires_at"`
@ -30,7 +30,7 @@ type OpenIDTokenCosmos struct {
type OpenIdTokenCosmosData struct {
cosmosdbapi.CosmosDocument
OpenIdToken OpenIDTokenCosmos `json:"mx_userapi_openidtoken"`
OpenIdToken openIDTokenCosmos `json:"mx_userapi_openidtoken"`
}
type tokenStatements struct {
@ -41,7 +41,15 @@ type tokenStatements struct {
serverName gomatrixserverlib.ServerName
}
func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken {
func (s *tokenStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *tokenStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func mapFromToken(db openIDTokenCosmos) api.OpenIDToken {
return api.OpenIDToken{
ExpiresAtMS: db.ExpiresAtMS,
Token: db.Token,
@ -49,35 +57,14 @@ func mapFromToken(db OpenIDTokenCosmos) api.OpenIDToken {
}
}
func mapToToken(api api.OpenIDToken) OpenIDTokenCosmos {
return OpenIDTokenCosmos{
func mapToToken(api api.OpenIDToken) openIDTokenCosmos {
return openIDTokenCosmos{
ExpiresAtMS: api.ExpiresAtMS,
Token: api.Token,
UserID: api.UserID,
}
}
func queryOpenIdToken(s *tokenStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OpenIdTokenCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []OpenIdTokenCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *tokenStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.selectTokenStmt = "select * from c where c._cn = @x1 and c.mx_userapi_openidtoken.token = @x2"
@ -95,20 +82,17 @@ func (s *tokenStatements) insertToken(
) (err error) {
// "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
docId := token
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var result = &api.OpenIDToken{
UserID: localpart,
Token: token,
ExpiresAtMS: expiresAtMS,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
docId := result.Token
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = OpenIdTokenCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
OpenIdToken: mapToToken(*result),
}
@ -136,22 +120,26 @@ func (s *tokenStatements) selectOpenIDTokenAtrributes(
var openIDTokenAttrs api.OpenIDTokenAttributes
// "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": token,
}
response, err := queryOpenIdToken(s, ctx, s.selectTokenStmt, params)
var rows []OpenIdTokenCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectTokenStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
if len(rows) == 0 {
return nil, nil
}
var openIdToken = response[0].OpenIdToken
var openIdToken = rows[0].OpenIdToken
openIDTokenAttrs = api.OpenIDTokenAttributes{
UserID: openIdToken.UserID,
ExpiresAtMS: openIdToken.ExpiresAtMS,

View file

@ -38,15 +38,15 @@ import (
// `
// Profile represents the profile for a Matrix account.
type ProfileCosmos struct {
type profileCosmos struct {
Localpart string `json:"local_part"`
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"`
}
type ProfileCosmosData struct {
type profileCosmosData struct {
cosmosdbapi.CosmosDocument
Profile ProfileCosmos `json:"mx_userapi_profile"`
Profile profileCosmos `json:"mx_userapi_profile"`
}
type profilesStatements struct {
@ -59,7 +59,15 @@ type profilesStatements struct {
tableName string
}
func mapFromProfile(db ProfileCosmos) authtypes.Profile {
func (s *profilesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *profilesStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func mapFromProfile(db profileCosmos) authtypes.Profile {
return authtypes.Profile{
AvatarURL: db.AvatarURL,
DisplayName: db.DisplayName,
@ -67,8 +75,8 @@ func mapFromProfile(db ProfileCosmos) authtypes.Profile {
}
}
func mapToProfile(api authtypes.Profile) ProfileCosmos {
return ProfileCosmos{
func mapToProfile(api authtypes.Profile) profileCosmos {
return profileCosmos{
AvatarURL: api.AvatarURL,
DisplayName: api.DisplayName,
Localpart: api.Localpart,
@ -83,29 +91,8 @@ func (s *profilesStatements) prepare(db *Database) (err error) {
return
}
func queryProfile(s *profilesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ProfileCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ProfileCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*ProfileCosmosData, error) {
response := ProfileCosmosData{}
func getProfile(s *profilesStatements, ctx context.Context, pk string, docId string) (*profileCosmosData, error) {
response := profileCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -121,7 +108,7 @@ func getProfile(s *profilesStatements, ctx context.Context, pk string, docId str
return &response, err
}
func setProfile(s *profilesStatements, ctx context.Context, profile ProfileCosmosData) (*ProfileCosmosData, error) {
func setProfile(s *profilesStatements, ctx context.Context, profile profileCosmosData) (*profileCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(profile.Pk, profile.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -142,14 +129,11 @@ func (s *profilesStatements) insertProfile(
Localpart: localpart,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = ProfileCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = profileCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Profile: mapToProfile(*result),
}
@ -169,27 +153,30 @@ func (s *profilesStatements) selectProfileByLocalpart(
) (*authtypes.Profile, error) {
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
}
response, err := queryProfile(s, ctx, s.selectProfileByLocalpartStmt, params)
var rows []profileCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectProfileByLocalpartStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(response)))
if len(rows) == 0 {
return nil, errors.New(fmt.Sprintf("Localpart %s not found", len(rows)))
}
if len(response) != 1 {
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(response)))
if len(rows) != 1 {
return nil, errors.New(fmt.Sprintf("Localpart %s has multiple entries", len(rows)))
}
var result = mapFromProfile(response[0].Profile)
var result = mapFromProfile(rows[0].Profile)
return &result, nil
}
@ -198,12 +185,10 @@ func (s *profilesStatements) setAvatarURL(
) (err error) {
// "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}
@ -222,11 +207,9 @@ func (s *profilesStatements) setDisplayName(
) (err error) {
// "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
docId := localpart
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response, exGet = getProfile(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}
@ -246,21 +229,24 @@ func (s *profilesStatements) selectProfilesBySearch(
var profiles []authtypes.Profile
// "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.profiles.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": searchString,
"@x3": limit,
}
response, err := queryProfile(s, ctx, s.selectProfilesBySearchStmt, params)
var rows []profileCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectProfilesBySearchStmt, params, &rows)
if err != nil {
return nil, err
}
for i := 0; i < len(response); i++ {
var responseData = response[i]
for i := 0; i < len(rows); i++ {
var responseData = rows[i]
profiles = append(profiles, mapFromProfile(responseData.Profile))
}

View file

@ -36,15 +36,15 @@ import (
// PRIMARY KEY(threepid, medium)
// );
type ThreePIDCosmos struct {
type threePIDCosmos struct {
Localpart string `json:"local_part"`
ThreePID string `json:"three_pid"`
Medium string `json:"medium"`
}
type ThreePIDCosmosData struct {
type threePIDCosmosData struct {
cosmosdbapi.CosmosDocument
ThreePID ThreePIDCosmos `json:"mx_userapi_threepid"`
ThreePID threePIDCosmos `json:"mx_userapi_threepid"`
}
type threepidStatements struct {
@ -56,6 +56,14 @@ type threepidStatements struct {
tableName string
}
func (s *threepidStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *threepidStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func (s *threepidStatements) prepare(db *Database) (err error) {
s.db = db
s.selectLocalpartForThreePIDStmt = "select * from c where c._cn = @x1 and c.mx_userapi_threepid.three_pid = @x2 and c.mx_userapi_threepid.medium = @x3"
@ -64,50 +72,33 @@ func (s *threepidStatements) prepare(db *Database) (err error) {
return
}
func queryThreePID(s *threepidStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ThreePIDCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []ThreePIDCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *threepidStatements) selectLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
// "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": threepid,
"@x3": medium,
}
response, err := queryThreePID(s, ctx, s.selectLocalpartForThreePIDStmt, params)
var rows []threePIDCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectLocalpartForThreePIDStmt, params, &rows)
if err != nil {
return "", err
}
if len(response) == 0 {
if len(rows) == 0 {
return "", nil
}
return response[0].ThreePID.Localpart, nil
return rows[0].ThreePID.Localpart, nil
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
@ -115,23 +106,27 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
) (threepids []authtypes.ThreePID, err error) {
// "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.threepids.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
}
response, err := queryThreePID(s, ctx, s.selectThreePIDsForLocalpartStmt, params)
var rows []threePIDCosmosData
err = cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectThreePIDsForLocalpartStmt, params, &rows)
if err != nil {
return threepids, err
}
if len(response) == 0 {
if len(rows) == 0 {
return threepids, nil
}
threepids = []authtypes.ThreePID{}
for _, item := range response {
for _, item := range rows {
threepids = append(threepids, authtypes.ThreePID{
Address: item.ThreePID.ThreePID,
Medium: item.ThreePID.Medium,
@ -145,19 +140,16 @@ func (s *threepidStatements) insertThreePID(
) (err error) {
// "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
var result = ThreePIDCosmos{
var result = threePIDCosmos{
Localpart: localpart,
Medium: medium,
ThreePID: threepid,
}
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var dbData = ThreePIDCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var dbData = threePIDCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
ThreePID: result,
}
@ -179,11 +171,9 @@ func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
// "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.accounts.tableName)
docId := fmt.Sprintf("%s_%s", threepid, medium)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey())
_, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,

View file

@ -49,7 +49,7 @@ import (
// );
// `
type DeviceCosmos struct {
type deviceCosmos struct {
ID string `json:"device_id"`
UserID string `json:"user_id"`
// The access_token granted to this device.
@ -69,12 +69,12 @@ type DeviceCosmos struct {
AppserviceID string `json:"app_service_id"`
}
type DeviceCosmosData struct {
type deviceCosmosData struct {
cosmosdbapi.CosmosDocument
Device DeviceCosmos `json:"mx_userapi_device"`
Device deviceCosmos `json:"mx_userapi_device"`
}
type DeviceCosmosSessionCount struct {
type deviceCosmosSessionCount struct {
SessionCount int64 `json:"sessioncount"`
}
@ -90,7 +90,15 @@ type devicesStatements struct {
tableName string
}
func mapFromDevice(db DeviceCosmos) api.Device {
func (s *devicesStatements) getCollectionName() string {
return cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
}
func (s *devicesStatements) getPartitionKey() string {
return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName())
}
func mapFromDevice(db deviceCosmos) api.Device {
return api.Device{
AccessToken: db.AccessToken,
AppserviceID: db.AppserviceID,
@ -103,9 +111,9 @@ func mapFromDevice(db DeviceCosmos) api.Device {
}
}
func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
func mapTodevice(api api.Device, s *devicesStatements) deviceCosmos {
localPart, _ := userutil.ParseUsernameParam(api.UserID, &s.serverName)
return DeviceCosmos{
return deviceCosmos{
AccessToken: api.AccessToken,
AppserviceID: api.AppserviceID,
ID: api.ID,
@ -118,29 +126,8 @@ func mapTodevice(api api.Device, s *devicesStatements) DeviceCosmos {
}
}
func queryDevice(s *devicesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []DeviceCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*DeviceCosmosData, error) {
response := DeviceCosmosData{}
func getDevice(s *devicesStatements, ctx context.Context, pk string, docId string) (*deviceCosmosData, error) {
response := deviceCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
@ -156,7 +143,7 @@ func getDevice(s *devicesStatements, ctx context.Context, pk string, docId strin
return &response, err
}
func setDevice(s *devicesStatements, ctx context.Context, device DeviceCosmosData) (*DeviceCosmosData, error) {
func setDevice(s *devicesStatements, ctx context.Context, device deviceCosmosData) (*deviceCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
@ -191,31 +178,31 @@ func (s *devicesStatements) insertDevice(
var sessionID int64
// "SELECT COUNT(access_token) FROM device_devices"
// HACK: Do we need a Cosmos Table for the sequence?
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response []DeviceCosmosSessionCount
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
}
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(s.selectDevicesCountStmt, params)
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
var rows []deviceCosmosSessionCount
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
s.getPartitionKey(), s.selectDevicesCountStmt, params, &rows)
if err != nil {
return nil, err
}
sessionID = response[0].SessionCount
sessionID = rows[0].SessionCount
sessionID++
// "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
data := DeviceCosmos{
// access_token TEXT PRIMARY KEY,
// UNIQUE (localpart, device_id)
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
data := deviceCosmos{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
@ -226,14 +213,8 @@ func (s *devicesStatements) insertDevice(
UserAgent: userAgent,
}
// access_token TEXT PRIMARY KEY,
// UNIQUE (localpart, device_id)
// HACK: check for duplicate PK as we are using the UNIQUE key for the DocId
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
var dbData = DeviceCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(dbCollectionName, s.db.cosmosConfig.TenantName, pk, cosmosDocId),
var dbData = deviceCosmosData{
CosmosDocument: cosmosdbapi.GenerateDocument(s.getCollectionName(), s.db.cosmosConfig.TenantName, s.getPartitionKey(), cosmosDocId),
Device: data,
}
@ -257,11 +238,9 @@ func (s *devicesStatements) deleteDevice(
ctx context.Context, id, localpart string,
) error {
// "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, id)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var options = cosmosdbapi.GetDeleteDocumentOptions(pk)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var options = cosmosdbapi.GetDeleteDocumentOptions(s.getPartitionKey())
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
@ -279,20 +258,23 @@ func (s *devicesStatements) deleteDevices(
ctx context.Context, localpart string, devices []string,
) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
"@x3": devices,
}
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
var rows []deviceCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectDevicesByLocalpartStmt, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
}
return err
@ -302,22 +284,25 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, localpart, exceptDeviceID string,
) error {
// "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
exceptDevices := []string{
exceptDeviceID,
}
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
"@x3": exceptDevices,
}
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartStmt, params)
var rows []deviceCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectDevicesByLocalpartStmt, params, &rows)
if err != nil {
return err
}
for _, item := range response {
for _, item := range rows {
s.deleteDevice(ctx, item.Device.ID, item.Device.Localpart)
}
return err
@ -327,11 +312,9 @@ func (s *devicesStatements) updateDeviceName(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
// "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}
@ -349,24 +332,27 @@ func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
// "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": accessToken,
}
response, err := queryDevice(s, ctx, s.selectDeviceByTokenStmt, params)
var rows []deviceCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectDeviceByTokenStmt, params, &rows)
if err != nil {
return nil, err
}
if len(response) == 0 {
if len(rows) == 0 {
return nil, cosmosdbutil.ErrNoRows
}
if err == nil {
result := mapFromDevice(response[0].Device)
result := mapFromDevice(rows[0].Device)
return &result, nil
}
return nil, err
@ -378,11 +364,9 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
// "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return nil, exGet
}
@ -395,20 +379,23 @@ func (s *devicesStatements) selectDevicesByLocalpart(
) ([]api.Device, error) {
devices := []api.Device{}
// "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": localpart,
"@x3": exceptDeviceID,
}
response, err := queryDevice(s, ctx, s.selectDevicesByLocalpartExceptIDStmt, params)
var rows []deviceCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectDevicesByLocalpartExceptIDStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
dev := mapFromDevice(item.Device)
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
@ -420,19 +407,21 @@ func (s *devicesStatements) selectDevicesByLocalpart(
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
// "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
var devices []api.Device
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
var response []DeviceCosmosData
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x1": s.getCollectionName(),
"@x2": deviceIDs,
}
response, err := queryDevice(s, ctx, s.selectDevicesByIDStmt, params)
var rows []deviceCosmosData
err := cosmosdbapi.PerformQuery(ctx,
s.db.connection,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
s.getPartitionKey(), s.selectDevicesByIDStmt, params, &rows)
if err != nil {
return nil, err
}
for _, item := range response {
for _, item := range rows {
dev := mapFromDevice(item.Device)
devices = append(devices, dev)
}
@ -443,11 +432,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart,
lastSeenTs := time.Now().UnixNano() / 1000000
// "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.devices.tableName)
docId := fmt.Sprintf("%s_%s", localpart, deviceID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
var response, exGet = getDevice(s, ctx, pk, cosmosDocId)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId)
var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId)
if exGet != nil {
return exGet
}