diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 4a03aca74..be75f8ad0 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -29,6 +29,7 @@ import ( type DatabaseTransaction interface { sqlutil.Transaction + Reset() (err error) SharedUsers MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index fb3b295e9..937ced3a2 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -77,6 +77,7 @@ func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransactio } return &DatabaseTransaction{ Database: d, + ctx: ctx, txn: txn, }, nil */ @@ -89,6 +90,7 @@ func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransac } return &DatabaseTransaction{ Database: d, + ctx: ctx, txn: txn, }, nil } diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index a19135a69..6cc83ebc8 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -13,6 +13,7 @@ import ( type DatabaseTransaction struct { *Database + ctx context.Context txn *sql.Tx } @@ -30,6 +31,19 @@ func (d *DatabaseTransaction) Rollback() error { return d.txn.Rollback() } +func (d *DatabaseTransaction) Reset() (err error) { + if d.txn == nil { + return nil + } + if err = d.txn.Rollback(); err != nil { + return err + } + if d.txn, err = d.DB.BeginTx(d.ctx, nil); err != nil { + return err + } + return +} + func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn) if err != nil { diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 34135d65a..f3e7fbdaa 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -54,7 +54,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( ) if err != nil { req.Log.WithError(err).Error("p.DB.GetAccountDataInRange failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 10ede573a..307099b8f 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -34,13 +34,13 @@ func (p *DeviceListStreamProvider) IncrementalSync( to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index b52eaaab1..4c889b8f5 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -56,7 +56,7 @@ func (p *InviteStreamProvider) IncrementalSync( ) if err != nil { req.Log.WithError(err).Error("p.DB.InviteEventsInRange failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index e1ee02b21..5154dd332 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -46,7 +46,7 @@ func (p *NotificationDataStreamProvider) IncrementalSync( countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 92e1bccf0..01ddf9ac9 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -75,7 +75,7 @@ func (p *PDUStreamProvider) CompleteSync( joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) if err != nil { req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } @@ -102,7 +102,9 @@ func (p *PDUStreamProvider) CompleteSync( ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") - _ = snapshot.Rollback() + if err = snapshot.Reset(); err != nil { + return from + } continue // return from } req.Response.Rooms.Join[roomID] = *jr @@ -113,7 +115,7 @@ func (p *PDUStreamProvider) CompleteSync( peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) if err != nil { req.Log.WithError(err).Error("p.DB.PeeksInRange failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } for _, peek := range peeks { @@ -124,7 +126,9 @@ func (p *PDUStreamProvider) CompleteSync( ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") - _ = snapshot.Rollback() + if err = snapshot.Reset(); err != nil { + return from + } continue // return from } req.Response.Rooms.Peek[peek.RoomID] = *jr @@ -156,13 +160,13 @@ func (p *PDUStreamProvider) IncrementalSync( if req.WantFullState { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return } } else { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return } } @@ -177,7 +181,7 @@ func (p *PDUStreamProvider) IncrementalSync( if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil { req.Log.WithError(err).Error("unable to update event filter with ignored users") - _ = snapshot.Rollback() + _ = snapshot.Reset() } newPos = from @@ -197,10 +201,12 @@ func (p *PDUStreamProvider) IncrementalSync( var pos types.StreamPosition if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") - _ = snapshot.Rollback() if err == context.DeadlineExceeded || err == context.Canceled { return newPos } + if err = snapshot.Reset(); err != nil { + return from + } continue // return to } // Reset the position, as it is only for the special case of newly joined rooms @@ -301,7 +307,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - _ = snapshot.Rollback() + _ = snapshot.Reset() } if len(delta.StateEvents) > 0 { diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index d24c85620..8a3f01c24 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -67,7 +67,7 @@ func (p *PresenceStreamProvider) IncrementalSync( presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 40e5bd01e..79fd65bfe 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -52,7 +52,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from) if err != nil { req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from } diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go index 3262832a3..c79efad06 100644 --- a/syncapi/streams/stream_sendtodevice.go +++ b/syncapi/streams/stream_sendtodevice.go @@ -44,7 +44,7 @@ func (p *SendToDeviceStreamProvider) IncrementalSync( lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) if err != nil { req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") - _ = snapshot.Rollback() + _ = snapshot.Reset() return from }