From 1506d595ded3f8211ac2bc0805a48af74ef92da3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 30 Sep 2022 10:12:03 +0100 Subject: [PATCH] Try to de-race stream positions --- syncapi/routing/filter.go | 8 +------- syncapi/storage/interface.go | 8 ++++---- syncapi/storage/shared/snapshot.go | 6 ------ syncapi/storage/shared/syncserver.go | 6 ++++++ syncapi/streams/stream_accountdata.go | 10 ++++++++-- syncapi/streams/stream_devicelist.go | 8 +++++++- syncapi/streams/stream_invite.go | 12 +++++++++--- syncapi/streams/stream_notificationdata.go | 15 ++++++++++++--- syncapi/streams/stream_pdu.go | 10 ++++++++-- syncapi/streams/stream_presence.go | 13 +++++++++++-- syncapi/streams/stream_receipt.go | 13 +++++++++++-- syncapi/streams/stream_sendtodevice.go | 13 +++++++++++-- syncapi/streams/stream_typing.go | 8 +++++++- syncapi/streams/streamprovider.go | 5 ++++- syncapi/streams/streams.go | 20 ++++++++++---------- syncapi/streams/template_stream.go | 2 +- syncapi/sync/request.go | 9 +-------- syncapi/syncapi.go | 11 ++++++++++- 18 files changed, 121 insertions(+), 56 deletions(-) diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index bb506ec39..f5acdbde3 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -45,14 +45,8 @@ func GetFilter( return jsonerror.InternalServerError() } - snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) - if err != nil { - return jsonerror.InternalServerError() - } - defer snapshot.Rollback() // nolint:errcheck - filter := gomatrixserverlib.DefaultFilter() - if err := snapshot.GetFilter(req.Context(), &filter, localpart, filterID); err != nil { + if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, // even though it is not correct. diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 14b8dbc0b..6ed9d7b16 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -90,10 +90,6 @@ type DatabaseSnapshot interface { // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // relevant events within the given ranges for the supplied user ID and device ID. SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) - // GetFilter looks up the filter associated with a given local user and filter ID - // and populates the target filter. Otherwise returns an error if no such filter exists - // or if there was an error talking to the database. - GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) @@ -118,6 +114,10 @@ type Database interface { NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseSnapshot, error) NewDatabaseWritable(ctx context.Context) (*shared.DatabaseSnapshot, error) + // GetFilter looks up the filter associated with a given local user and filter ID + // and populates the target filter. Otherwise returns an error if no such filter exists + // or if there was an error talking to the database. + GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. diff --git a/syncapi/storage/shared/snapshot.go b/syncapi/storage/shared/snapshot.go index deac73482..834d9d869 100644 --- a/syncapi/storage/shared/snapshot.go +++ b/syncapi/storage/shared/snapshot.go @@ -253,12 +253,6 @@ func (d *DatabaseSnapshot) StreamToTopologicalPosition( } } -func (d *DatabaseSnapshot) GetFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, -) error { - return d.Filter.SelectFilter(ctx, d.txn, target, localpart, filterID) -} - // GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // oldest event in the room's topology. func (d *DatabaseSnapshot) GetBackwardTopologyPos( diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 8ad7b32ab..edd28fac0 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -85,6 +85,12 @@ func (d *Database) NewDatabaseWritable(ctx context.Context) (*DatabaseSnapshot, }, nil } +func (d *Database) GetFilter( + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { + return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) +} + func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) if err != nil { diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index e9db33061..5aed57904 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -23,11 +23,17 @@ func (p *AccountDataStreamProvider) Setup( p.latestMutex.Lock() defer p.latestMutex.Unlock() + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *AccountDataStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForAccountData(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *AccountDataStreamProvider) CompleteSync( @@ -35,7 +41,7 @@ func (p *AccountDataStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *AccountDataStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index a806ef111..d22f4963f 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -16,12 +16,18 @@ type DeviceListStreamProvider struct { keyAPI keyapi.SyncKeyAPI } +func (p *DeviceListStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { + return 0 // TODO: is this the right thing to do? +} + func (p *DeviceListStreamProvider) CompleteSync( ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.LatestPosition(ctx) + return p.latestPosition(ctx, snapshot) } func (p *DeviceListStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 029302262..d6a8a8940 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -25,11 +25,17 @@ func (p *InviteStreamProvider) Setup( p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := snapshot.MaxStreamPositionForInvites(context.Background()) + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *InviteStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { + id, err := snapshot.MaxStreamPositionForAccountData(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *InviteStreamProvider) CompleteSync( @@ -37,7 +43,7 @@ func (p *InviteStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *InviteStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 6944640e9..3c315837b 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -16,11 +16,20 @@ func (p *NotificationDataStreamProvider) Setup( ) { p.DefaultStreamProvider.Setup(ctx, snapshot) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *NotificationDataStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForNotificationData(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *NotificationDataStreamProvider) CompleteSync( @@ -28,7 +37,7 @@ func (p *NotificationDataStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *NotificationDataStreamProvider) IncrementalSync( @@ -59,5 +68,5 @@ func (p *NotificationDataStreamProvider) IncrementalSync( req.Response.Rooms.Join[roomID] = jr } - return p.LatestPosition(ctx) + return p.latestPosition(ctx, snapshot) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 4779558bd..366878b68 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -47,11 +47,17 @@ func (p *PDUStreamProvider) Setup( p.latestMutex.Lock() defer p.latestMutex.Unlock() + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *PDUStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForPDUs(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *PDUStreamProvider) CompleteSync( @@ -60,7 +66,7 @@ func (p *PDUStreamProvider) CompleteSync( req *types.SyncRequest, ) types.StreamPosition { from := types.StreamPosition(0) - to := p.LatestPosition(ctx) + to := p.latestPosition(ctx, snapshot) // Get the current sync position which we will base the sync response on. // For complete syncs, we want to start at the most recent events and work diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index b2f9a0b47..ce3c2f0cd 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -39,11 +39,20 @@ func (p *PresenceStreamProvider) Setup( ) { p.DefaultStreamProvider.Setup(ctx, snapshot) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *PresenceStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForPresence(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *PresenceStreamProvider) CompleteSync( @@ -51,7 +60,7 @@ func (p *PresenceStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *PresenceStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index e7140b7e1..0e8eaffa6 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -18,11 +18,20 @@ func (p *ReceiptStreamProvider) Setup( ) { p.DefaultStreamProvider.Setup(ctx, snapshot) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *ReceiptStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForReceipts(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *ReceiptStreamProvider) CompleteSync( @@ -30,7 +39,7 @@ func (p *ReceiptStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *ReceiptStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go index 0bd2e2c6d..ef6c0b500 100644 --- a/syncapi/streams/stream_sendtodevice.go +++ b/syncapi/streams/stream_sendtodevice.go @@ -16,11 +16,20 @@ func (p *SendToDeviceStreamProvider) Setup( ) { p.DefaultStreamProvider.Setup(ctx, snapshot) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + p.latest = p.latestPosition(ctx, snapshot) +} + +func (p *SendToDeviceStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(context.Background()) if err != nil { panic(err) } - p.latest = id + return id } func (p *SendToDeviceStreamProvider) CompleteSync( @@ -28,7 +37,7 @@ func (p *SendToDeviceStreamProvider) CompleteSync( snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *SendToDeviceStreamProvider) IncrementalSync( diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go index e895b80d7..d65f9c805 100644 --- a/syncapi/streams/stream_typing.go +++ b/syncapi/streams/stream_typing.go @@ -15,12 +15,18 @@ type TypingStreamProvider struct { EDUCache *caching.EDUCache } +func (p *TypingStreamProvider) latestPosition( + ctx context.Context, snapshot storage.DatabaseSnapshot, +) types.StreamPosition { + return types.StreamPosition(p.EDUCache.GetLatestSyncPosition()) +} + func (p *TypingStreamProvider) CompleteSync( ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.latestPosition(ctx, snapshot)) } func (p *TypingStreamProvider) IncrementalSync( diff --git a/syncapi/streams/streamprovider.go b/syncapi/streams/streamprovider.go index cd21d9fb9..a269b20a1 100644 --- a/syncapi/streams/streamprovider.go +++ b/syncapi/streams/streamprovider.go @@ -24,5 +24,8 @@ type StreamProvider interface { IncrementalSync(ctx context.Context, snapshot storage.DatabaseSnapshot, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition // LatestPosition returns the latest stream position for this stream. - LatestPosition(ctx context.Context) types.StreamPosition + LatestPosition(ctx context.Context, snapshot storage.DatabaseSnapshot) types.StreamPosition + + // latestPosition gets the latest stream position from the database for this stream. + latestPosition(ctx context.Context, snapshot storage.DatabaseSnapshot) types.StreamPosition } diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index eccbb3a4f..f4113dd27 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -87,16 +87,16 @@ func NewSyncStreamProviders( return streams } -func (s *Streams) Latest(ctx context.Context) types.StreamingToken { +func (s *Streams) Latest(ctx context.Context, snapshot storage.DatabaseSnapshot) types.StreamingToken { return types.StreamingToken{ - PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), - TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), - ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx), - InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), - SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), - AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), - NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx), - DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), - PresencePosition: s.PresenceStreamProvider.LatestPosition(ctx), + PDUPosition: s.PDUStreamProvider.LatestPosition(ctx, snapshot), + TypingPosition: s.TypingStreamProvider.LatestPosition(ctx, snapshot), + ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx, snapshot), + InvitePosition: s.InviteStreamProvider.LatestPosition(ctx, snapshot), + SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx, snapshot), + AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx, snapshot), + NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx, snapshot), + DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx, snapshot), + PresencePosition: s.PresenceStreamProvider.LatestPosition(ctx, snapshot), } } diff --git a/syncapi/streams/template_stream.go b/syncapi/streams/template_stream.go index b778be53f..78352b033 100644 --- a/syncapi/streams/template_stream.go +++ b/syncapi/streams/template_stream.go @@ -31,7 +31,7 @@ func (p *DefaultStreamProvider) Advance( } func (p *DefaultStreamProvider) LatestPosition( - ctx context.Context, + ctx context.Context, snapshot storage.DatabaseSnapshot, ) types.StreamPosition { p.latestMutex.RLock() defer p.latestMutex.RUnlock() diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 8b6d78619..268ed70c6 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -48,13 +48,6 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } } - snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) - if err != nil { - logrus.WithError(err).Error("Failed to acquire database snapshot for sync request") - return nil, err - } - defer snapshot.Rollback() // nolint:errcheck - // Create a default filter and apply a stored filter on top of it (if specified) filter := gomatrixserverlib.DefaultFilter() filterQuery := req.URL.Query().Get("filter") @@ -71,7 +64,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := snapshot.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows { + if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows { util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") return nil, fmt.Errorf("syncDB.GetFilter: %w", err) } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index be19310f2..8fb5ab9dd 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -56,7 +56,16 @@ func AddPublicRoutes( eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier) - notifier.SetCurrentPosition(streams.Latest(context.Background())) + + snapshot, err := syncDB.NewDatabaseSnapshot(base.ProcessContext.Context()) + if err != nil { + logrus.WithError(err).Fatalf("Failed to acquire database snapshot for sync startup") + } + notifier.SetCurrentPosition(streams.Latest(context.Background(), snapshot)) + if err = snapshot.Rollback(); err != nil { + logrus.WithError(err).Fatalf("Failed to roll back snapshot for sync startup") + } + if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") }