diff --git a/appservice/storage/cosmosdb/appservice_events_table.go b/appservice/storage/cosmosdb/appservice_events_table.go index f7cc63e79..fd40d0160 100644 --- a/appservice/storage/cosmosdb/appservice_events_table.go +++ b/appservice/storage/cosmosdb/appservice_events_table.go @@ -144,18 +144,6 @@ func getEvent(s *eventsStatements, ctx context.Context, pk string, docId string) return &response, err } -func setEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - event.Id, - &event, - optionsReplace) - return &event, ex -} - func deleteEvent(s *eventsStatements, ctx context.Context, event eventCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(event.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( @@ -392,7 +380,8 @@ func (s *eventsStatements) updateTxnIDForEvents( for _, item := range rows { item.Event.TXNID = int64(txnID) // _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) - _, err = setEvent(s, ctx, item) + item.SetUpdateTime() + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) } return err diff --git a/federationsender/storage/cosmosdb/inbound_peeks_table.go b/federationsender/storage/cosmosdb/inbound_peeks_table.go index 9dd5256dd..da9c8d095 100644 --- a/federationsender/storage/cosmosdb/inbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/inbound_peeks_table.go @@ -114,18 +114,6 @@ func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, d return &response, err } -func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek inboundPeekCosmosData) (*inboundPeekCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - inboundPeek.Id, - &inboundPeek, - optionsReplace) - return &inboundPeek, ex -} - func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData inboundPeekCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( @@ -209,20 +197,21 @@ func (s *inboundPeeksStatements) RenewInboundPeek( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) + item, err := getInboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if err != nil { return } - if res == nil { + if item == nil { return } - res.InboundPeek.RenewedTimestamp = nowMilli - res.InboundPeek.RenewalInterval = renewalInterval + item.SetUpdateTime() + item.InboundPeek.RenewedTimestamp = nowMilli + item.InboundPeek.RenewalInterval = renewalInterval - _, err = setInboundPeek(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return } diff --git a/federationsender/storage/cosmosdb/outbound_peeks_table.go b/federationsender/storage/cosmosdb/outbound_peeks_table.go index cc707886f..cf9ec1830 100644 --- a/federationsender/storage/cosmosdb/outbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/outbound_peeks_table.go @@ -111,18 +111,6 @@ func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, return &response, err } -func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek outboundPeekCosmosData) (*outboundPeekCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - outboundPeek.Id, - &outboundPeek, - optionsReplace) - return &outboundPeek, ex -} - func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData outboundPeekCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( @@ -206,20 +194,21 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) // _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) - res, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) + item, err := getOutboundPeek(s, ctx, s.getPartitionKey(roomID), cosmosDocId) if err != nil { return } - if res == nil { + if item == nil { return } - res.OutboundPeek.RenewedTimestamp = nowMilli - res.OutboundPeek.RenewalInterval = renewalInterval + item.SetUpdateTime() + item.OutboundPeek.RenewedTimestamp = nowMilli + item.OutboundPeek.RenewalInterval = renewalInterval - _, err = setOutboundPeek(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return } diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go index 29837f00f..eeae75573 100644 --- a/internal/cosmosdbapi/client.go +++ b/internal/cosmosdbapi/client.go @@ -61,8 +61,8 @@ func PerformQuery(ctx context.Context, if err != nil { return err } - optionsQry := GetQueryDocumentsOptions(partitonKey) - var query = GetQuery(qryString, params) + optionsQry := getQueryDocumentsOptions(partitonKey) + var query = getQuery(qryString, params) _, err = GetClient(conn).QueryDocuments( ctx, databaseName, @@ -84,8 +84,8 @@ func PerformQueryAllPartitions(ctx context.Context, if err != nil { return err } - var optionsQry = GetQueryAllPartitionsDocumentsOptions() - var query = GetQuery(qryString, params) + var optionsQry = getQueryAllPartitionsDocumentsOptions() + var query = getQuery(qryString, params) _, err = GetClient(conn).QueryDocuments( ctx, databaseName, @@ -141,6 +141,26 @@ func GetDocumentOrNil(connection CosmosConnection, config CosmosConfig, ctx cont return nil } +func UpdateDocument(ctx context.Context, + conn CosmosConnection, + databaseName string, + containerName string, + partitionKey string, + eTag string, + docId string, + document interface{}, +) (*interface{}, error) { + optionsReplace := getReplaceDocumentOptions(partitionKey, eTag) + _, _, err := GetClient(conn).ReplaceDocument( + ctx, + databaseName, + containerName, + docId, + &document, + optionsReplace) + return &document, err +} + func validateQuery(qryString string) error { if len(qryString) == 0 { return errors.New("qryString was nil") diff --git a/internal/cosmosdbapi/documentoperations.go b/internal/cosmosdbapi/documentoperations.go index 6ec944fe3..2285b3511 100644 --- a/internal/cosmosdbapi/documentoperations.go +++ b/internal/cosmosdbapi/documentoperations.go @@ -18,7 +18,7 @@ func getUpsertDocumentOptions(pk string) cosmosapi.CreateDocumentOptions { } } -func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { +func getQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { return cosmosapi.QueryDocumentsOptions{ PartitionKeyValue: pk, IsQuery: true, @@ -26,7 +26,7 @@ func GetQueryDocumentsOptions(pk string) cosmosapi.QueryDocumentsOptions { } } -func GetQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions { +func getQueryAllPartitionsDocumentsOptions() cosmosapi.QueryDocumentsOptions { return cosmosapi.QueryDocumentsOptions{ IsQuery: true, EnableCrossPartition: true, @@ -40,7 +40,7 @@ func GetGetDocumentOptions(pk string) cosmosapi.GetDocumentOptions { } } -func GetReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions { +func getReplaceDocumentOptions(pk string, etag string) cosmosapi.ReplaceDocumentOptions { return cosmosapi.ReplaceDocumentOptions{ PartitionKeyValue: pk, IfMatch: etag, diff --git a/internal/cosmosdbapi/query.go b/internal/cosmosdbapi/query.go index 29e46be23..c3deebd02 100644 --- a/internal/cosmosdbapi/query.go +++ b/internal/cosmosdbapi/query.go @@ -4,17 +4,17 @@ import ( cosmosapi "github.com/vippsas/go-cosmosdb/cosmosapi" ) -func GetQuery(qry string, params map[string]interface{}) cosmosapi.Query { +func getQuery(qry string, params map[string]interface{}) cosmosapi.Query { qryParams := []cosmosapi.QueryParam{} for key, value := range params { - qryParam := cosmosapi.QueryParam { - Name: key, + qryParam := cosmosapi.QueryParam{ + Name: key, Value: value, } qryParams = append(qryParams, qryParam) - } - return cosmosapi.Query { - Query: qry, + } + return cosmosapi.Query{ + Query: qry, Params: qryParams, } -} \ No newline at end of file +} diff --git a/internal/cosmosdbutil/document_seq.go b/internal/cosmosdbutil/document_seq.go index 09963dbe5..caaf78b33 100644 --- a/internal/cosmosdbutil/document_seq.go +++ b/internal/cosmosdbutil/document_seq.go @@ -53,15 +53,8 @@ func GetNextSequence( } } else { dbData.Value++ - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(dbData.Pk, dbData.ETag) - var _, _, err = cosmosdbapi.GetClient(connection).ReplaceDocument( - ctx, - config.DatabaseName, - config.ContainerName, - cosmosDocId, - dbData, - optionsReplace, - ) + dbData.SetUpdateTime() + _, err := cosmosdbapi.UpdateDocument(ctx, connection, config.DatabaseName, config.ContainerName, dbData.Pk, dbData.ETag, dbData.Id, dbData) if err != nil { return -1, err } diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index 470483656..ec63120ac 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -233,18 +233,6 @@ func getEvent(s *eventStatements, ctx context.Context, pk string, docId string) return &response, err } -func setEvent(s *eventStatements, ctx context.Context, event eventCosmosData) (*eventCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(event.Pk, event.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - event.Id, - &event, - optionsReplace) - return &event, ex -} - func isEventAuthEventNIDsSame( ids []int64, authEventNIDs []types.EventNID, @@ -634,9 +622,10 @@ func (s *eventStatements) UpdateEventState( } item := rows[0] + item.SetUpdateTime() item.Event.StateSnapshotNID = int64(stateNID) - var _, exReplace = setEvent(s, ctx, item) + _, exReplace := cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) if exReplace != nil { return exReplace } @@ -691,9 +680,10 @@ func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql. } item := rows[0] + item.SetUpdateTime() item.Event.SentToOutput = true - var _, exReplace = setEvent(s, ctx, item) + _, exReplace := cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) if exReplace != nil { return exReplace } diff --git a/roomserver/storage/cosmosdb/invite_table.go b/roomserver/storage/cosmosdb/invite_table.go index c5083fabd..9ba97077e 100644 --- a/roomserver/storage/cosmosdb/invite_table.go +++ b/roomserver/storage/cosmosdb/invite_table.go @@ -120,18 +120,6 @@ func getInvite(s *inviteStatements, ctx context.Context, pk string, docId string return &response, err } -func setInvite(s *inviteStatements, ctx context.Context, invite inviteCosmosData) (*inviteCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - invite.Id, - &invite, - optionsReplace) - return &invite, ex -} - func NewCosmosDBInvitesTable(db *Database) (tables.Invites, error) { s := &inviteStatements{ db: db, @@ -224,8 +212,9 @@ func (s *inviteStatements) UpdateInviteRetired( // UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired // now retire the invites + item.SetUpdateTime() item.Invite.Retired = true - _, err = setInvite(s, ctx, item) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) } return diff --git a/roomserver/storage/cosmosdb/membership_table.go b/roomserver/storage/cosmosdb/membership_table.go index 452c1fbcc..a2c415f61 100644 --- a/roomserver/storage/cosmosdb/membership_table.go +++ b/roomserver/storage/cosmosdb/membership_table.go @@ -232,18 +232,6 @@ func getMembership(s *membershipStatements, ctx context.Context, pk string, docI return &response, err } -func setMembership(s *membershipStatements, ctx context.Context, membership membershipCosmosData) (*membershipCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(membership.Pk, membership.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - membership.Id, - &membership, - optionsReplace) - return &membership, ex -} - func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) { s := &membershipStatements{ db: db, @@ -290,7 +278,8 @@ func (s *membershipStatements) InsertMembership( exists.Membership.TargetNID = int64(targetUserNID) exists.Membership.TargetLocal = localTarget exists.SetUpdateTime() - _, errSet := setMembership(s, ctx, *exists) + exists.SetUpdateTime() + _, errSet := cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, exists.Pk, exists.ETag, exists.Id, exists) return errSet } @@ -455,8 +444,9 @@ func (s *membershipStatements) UpdateMembership( dbData.Membership.MembershipNID = int64(membership) dbData.Membership.EventNID = int64(eventNID) dbData.Membership.Forgotten = forgotten + dbData.SetUpdateTime() - _, err = setMembership(s, ctx, *dbData) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, dbData.Pk, dbData.ETag, dbData.Id, dbData) return err } @@ -711,8 +701,9 @@ func (s *membershipStatements) UpdateForgetMembership( return err } + dbData.SetUpdateTime() dbData.Membership.Forgotten = forget - _, err = setMembership(s, ctx, *dbData) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, dbData.Pk, dbData.ETag, dbData.Id, dbData) return err } diff --git a/roomserver/storage/cosmosdb/redactions_table.go b/roomserver/storage/cosmosdb/redactions_table.go index 0f71819d4..58ad9bd3e 100644 --- a/roomserver/storage/cosmosdb/redactions_table.go +++ b/roomserver/storage/cosmosdb/redactions_table.go @@ -99,18 +99,6 @@ func getRedaction(s *redactionStatements, ctx context.Context, pk string, docId return &response, err } -func setRedaction(s *redactionStatements, ctx context.Context, redaction redactionCosmosData) (*redactionCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(redaction.Pk, redaction.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - redaction.Id, - &redaction, - optionsReplace) - return &redaction, ex -} - func NewCosmosDBRedactionsTable(db *Database) (tables.Redactions, error) { s := &redactionStatements{ db: db, @@ -242,13 +230,14 @@ func (s *redactionStatements) MarkRedactionValidated( docId := redactionEventID cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - response, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) + item, err := getRedaction(s, ctx, s.getPartitionKey(), cosmosDocId) if err != nil { return err } - response.Redaction.Validated = validated + item.SetUpdateTime() + item.Redaction.Validated = validated - _, err = setRedaction(s, ctx, *response) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return err } diff --git a/roomserver/storage/cosmosdb/rooms_table.go b/roomserver/storage/cosmosdb/rooms_table.go index d342535f8..53f517752 100644 --- a/roomserver/storage/cosmosdb/rooms_table.go +++ b/roomserver/storage/cosmosdb/rooms_table.go @@ -164,18 +164,6 @@ func getRoom(s *roomStatements, ctx context.Context, pk string, docId string) (* return &response, err } -func setRoom(s *roomStatements, ctx context.Context, room roomCosmosData) (*roomCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(room.Pk, room.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - room.Id, - &room, - optionsReplace) - return &room, ex -} - func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { // "SELECT room_id FROM roomserver_rooms" @@ -398,11 +386,12 @@ func (s *roomStatements) UpdateLatestEventNIDs( //Assume 1 per RoomNID room := rows[0] + room.SetUpdateTime() room.Room.LatestEventNIDs = mapFromEventNIDArray(eventNIDs) room.Room.LastEventSentNID = int64(lastEventSentNID) room.Room.StateSnapshotNID = int64(stateSnapshotNID) - _, err = setRoom(s, ctx, room) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, room.Pk, room.ETag, room.Id, room) return err } diff --git a/roomserver/storage/cosmosdb/state_block_table.go b/roomserver/storage/cosmosdb/state_block_table.go index 2b25d0c73..5b898db93 100644 --- a/roomserver/storage/cosmosdb/state_block_table.go +++ b/roomserver/storage/cosmosdb/state_block_table.go @@ -107,18 +107,6 @@ func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docI return &response, err } -func setStateBlock(s *stateBlockStatements, ctx context.Context, item stateBlockCosmosData) (*stateBlockCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(item.Pk, item.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - item.Id, - &item, - optionsReplace) - return &item, ex -} - func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, @@ -168,9 +156,10 @@ func (s *stateBlockStatements) BulkInsertStateData( } if existing != nil { //if exists, just update and dont create a new seq + existing.SetUpdateTime() existing.StateBlock.EventNIDs = ids existing.SetUpdateTime() - _, err = setStateBlock(s, ctx, *existing) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, existing.Pk, existing.ETag, existing.Id, existing) if err != nil { return 0, err } diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go index 611301881..d05ed268e 100644 --- a/syncapi/storage/cosmosdb/invites_table.go +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -121,18 +121,6 @@ func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, d return &response, err } -func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite inviteEventCosmosData) (*inviteEventCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - invite.Id, - &invite, - optionsReplace) - return &invite, ex -} - func NewCosmosDBInvitesTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, @@ -225,9 +213,10 @@ func (s *inviteEventsStatements) DeleteInviteEvent( s.getPartitionKey(), s.deleteInviteEventStmt, params, &rows) for _, item := range rows { + item.SetUpdateTime() item.InviteEvent.Deleted = true item.InviteEvent.ID = int64(streamPos) - setInviteEvent(s, ctx, item) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) } return streamPos, err } diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go index 988fd9125..4b2cbced0 100644 --- a/syncapi/storage/cosmosdb/output_room_events_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -160,18 +160,6 @@ func (s *outputRoomEventsStatements) getPartitionKey() string { return cosmosdbapi.GetPartitionKeyByCollection(s.db.cosmosConfig.TenantName, s.getCollectionName()) } -func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent outputRoomEventCosmosData) (*outputRoomEventCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outputRoomEvent.Pk, outputRoomEvent.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - outputRoomEvent.Id, - &outputRoomEvent, - optionsReplace) - return &outputRoomEvent, ex -} - func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData outputRoomEventCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( @@ -227,8 +215,9 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event } for _, item := range rows { + item.SetUpdateTime() item.OutputRoomEvent.HeaderedEventJSON = headeredJSON - _, err = setOutputRoomEvent(s, ctx, item) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) } return err diff --git a/syncapi/storage/cosmosdb/peeks_table.go b/syncapi/storage/cosmosdb/peeks_table.go index d13927f98..e904f0426 100644 --- a/syncapi/storage/cosmosdb/peeks_table.go +++ b/syncapi/storage/cosmosdb/peeks_table.go @@ -140,18 +140,6 @@ func getPeek(s *peekStatements, ctx context.Context, pk string, docId string) (* return &response, err } -func setPeek(s *peekStatements, ctx context.Context, peek peekCosmosData) (*peekCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(peek.Pk, peek.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - peek.Id, - &peek, - optionsReplace) - return &peek, ex -} - func NewCosmosDBPeeksTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Peeks, error) { s := &peekStatements{ db: db, @@ -249,9 +237,10 @@ func (s *peekStatements) DeletePeek( } for _, item := range rows { + item.SetUpdateTime() item.Peek.Deleted = true item.Peek.ID = int64(streamPos) - _, err = setPeek(s, ctx, item) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) if err != nil { return } @@ -293,9 +282,10 @@ func (s *peekStatements) DeletePeeks( } for _, item := range rows { + item.SetUpdateTime() item.Peek.Deleted = true item.Peek.ID = int64(streamPos) - _, err = setPeek(s, ctx, item) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) if err != nil { return 0, err } diff --git a/userapi/storage/accounts/cosmosdb/accounts_table.go b/userapi/storage/accounts/cosmosdb/accounts_table.go index b5727ee12..84ac51e0d 100644 --- a/userapi/storage/accounts/cosmosdb/accounts_table.go +++ b/userapi/storage/accounts/cosmosdb/accounts_table.go @@ -109,19 +109,6 @@ func getAccount(s *accountsStatements, ctx context.Context, pk string, docId str return &response, err } -func setAccount(s *accountsStatements, ctx context.Context, account accountCosmosData) (*accountCosmosData, error) { - response := accountCosmosData{} - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(account.Pk, account.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - account.Id, - &account, - optionsReplace) - return &response, ex -} - func mapFromAccount(db accountCosmos) api.Account { return api.Account{ AppServiceID: db.AppServiceID, @@ -193,16 +180,17 @@ func (s *accountsStatements) updatePassword( docId := localpart cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Account.PasswordHash = passwordHash + item.SetUpdateTime() + item.Account.PasswordHash = passwordHash - var _, exReplace = setAccount(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } return } @@ -215,16 +203,17 @@ func (s *accountsStatements) deactivateAccount( docId := localpart cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getAccount(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Account.IsDeactivated = true + item.SetUpdateTime() + item.Account.IsDeactivated = true - var _, exReplace = setAccount(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } return } diff --git a/userapi/storage/accounts/cosmosdb/key_backup_table.go b/userapi/storage/accounts/cosmosdb/key_backup_table.go index 60f74441f..adc9b0456 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_table.go @@ -136,18 +136,6 @@ func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId return &response, err } -func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup keyBackupCosmosData) (*keyBackupCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - keyBackup.Id, - &keyBackup, - optionsReplace) - return &keyBackup, ex -} - func (s *keyBackupStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db // s.insertBackupKeyStmt = insertBackupKeySQL @@ -243,23 +231,24 @@ func (s *keyBackupStatements) updateBackupKey( docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackup(s, ctx, s.getPartitionKey(userID), cosmosDocId) + item, err := getKeyBackup(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return } - if res == nil { + if item == nil { return } // ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, - res.KeyBackup.FirstMessageIndex = key.FirstMessageIndex - res.KeyBackup.ForwardedCount = key.ForwardedCount - res.KeyBackup.IsVerified = key.IsVerified - res.KeyBackup.SessionData = key.SessionData + item.SetUpdateTime() + item.KeyBackup.FirstMessageIndex = key.FirstMessageIndex + item.KeyBackup.ForwardedCount = key.ForwardedCount + item.KeyBackup.IsVerified = key.IsVerified + item.KeyBackup.SessionData = key.SessionData - _, err = setKeyBackup(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return } diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go index 5079ae41d..20d66cd33 100644 --- a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go @@ -117,18 +117,6 @@ func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk return &response, err } -func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup keyBackupVersionCosmosData) (*keyBackupVersionCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - keyBackup.Id, - &keyBackup, - optionsReplace) - return &keyBackup, ex -} - func (s *keyBackupVersionStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db // s.insertKeyBackupStmt = insertKeyBackupSQL @@ -196,20 +184,21 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) + item, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return err } - if res == nil { + if item == nil { return err } // _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) - res.KeyBackupVersion.AuthData = authData + item.SetUpdateTime() + item.KeyBackupVersion.AuthData = authData - _, err = setKeyBackupVersion(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return err } @@ -226,20 +215,21 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) + item, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return err } - if res == nil { + if item == nil { return err } // _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) - res.KeyBackupVersion.Etag = etag + item.SetUpdateTime() + item.KeyBackupVersion.Etag = etag - _, err = setKeyBackupVersion(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) return err } @@ -256,20 +246,21 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( docId := fmt.Sprintf("%s_%d", userID, versionInt) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - res, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) + item, err := getKeyBackupVersion(s, ctx, s.getPartitionKey(userID), cosmosDocId) if err != nil { return false, err } - if res == nil { + if item == nil { return false, err } // result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) - res.KeyBackupVersion.Deleted = 1 + item.SetUpdateTime() + item.KeyBackupVersion.Deleted = 1 - _, err = setKeyBackupVersion(s, ctx, *res) + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) if err != nil { return false, err diff --git a/userapi/storage/accounts/cosmosdb/profile_table.go b/userapi/storage/accounts/cosmosdb/profile_table.go index 418e48acb..04089e010 100644 --- a/userapi/storage/accounts/cosmosdb/profile_table.go +++ b/userapi/storage/accounts/cosmosdb/profile_table.go @@ -108,18 +108,6 @@ func getProfile(s *profilesStatements, ctx context.Context, pk string, docId str return &response, err } -func setProfile(s *profilesStatements, ctx context.Context, profile profileCosmosData) (*profileCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(profile.Pk, profile.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - profile.Id, - &profile, - optionsReplace) - return &profile, ex -} - func (s *profilesStatements) insertProfile( ctx context.Context, localpart string, ) error { @@ -188,16 +176,17 @@ func (s *profilesStatements) setAvatarURL( docId := localpart cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Profile.AvatarURL = avatarURL + item.SetUpdateTime() + item.Profile.AvatarURL = avatarURL - var _, exReplace = setProfile(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } return } @@ -209,16 +198,17 @@ func (s *profilesStatements) setDisplayName( // "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" docId := localpart cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getProfile(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Profile.DisplayName = displayName + item.SetUpdateTime() + item.Profile.DisplayName = displayName - var _, exReplace = setProfile(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err = cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } return } diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index 2d7dab646..aed388b75 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -144,18 +144,6 @@ func getDevice(s *devicesStatements, ctx context.Context, pk string, docId strin return &response, err } -func setDevice(s *devicesStatements, ctx context.Context, device deviceCosmosData) (*deviceCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag) - var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - device.Id, - &device, - optionsReplace) - return &device, ex -} - func (s *devicesStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { s.db = db s.selectDevicesCountStmt = "select count(c._ts) as sessioncount from c where c._cn = @x1" @@ -315,18 +303,19 @@ func (s *devicesStatements) updateDeviceName( // "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" docId := fmt.Sprintf("%s_%s", localpart, deviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Device.DisplayName = *displayName + item.SetUpdateTime() + item.Device.DisplayName = *displayName - var _, exReplace = setDevice(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err := cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } - return exReplace + return err } func (s *devicesStatements) selectDeviceByToken( @@ -435,17 +424,18 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, // "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" docId := fmt.Sprintf("%s_%s", localpart, deviceID) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, s.getCollectionName(), docId) - var response, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) + var item, exGet = getDevice(s, ctx, s.getPartitionKey(), cosmosDocId) if exGet != nil { return exGet } - response.Device.LastSeenTS = lastSeenTs - response.Device.LastSeenIP = ipAddr + item.SetUpdateTime() + item.Device.LastSeenTS = lastSeenTs + item.Device.LastSeenIP = ipAddr - var _, exReplace = setDevice(s, ctx, *response) - if exReplace != nil { - return exReplace + _, err := cosmosdbapi.UpdateDocument(ctx, s.db.connection, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, item.Pk, item.ETag, item.Id, item) + if err != nil { + return err } - return exReplace + return err }