From c32d471dfb15925fc2f63ac5b88f01c53933c486 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 28 Sep 2022 14:28:39 +0100 Subject: [PATCH] Don't use transactional isolation on SQLite --- syncapi/storage/shared/syncserver.go | 94 +++++++++++++++------------ syncapi/storage/sqlite3/syncserver.go | 7 ++ 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index e514551df..73042d3d1 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -57,7 +57,21 @@ type Database struct { type DatabaseSnapshot struct { *Database - *sql.Tx + txn *sql.Tx +} + +func (d *DatabaseSnapshot) Commit() error { + if d.txn == nil { + return nil + } + return d.txn.Commit() +} + +func (d *DatabaseSnapshot) Rollback() error { + if d.txn == nil { + return nil + } + return d.txn.Rollback() } func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseSnapshot, error) { @@ -75,7 +89,7 @@ func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseSnapshot, } return &DatabaseSnapshot{ Database: d, - Tx: txn, + txn: txn, }, nil } @@ -128,39 +142,39 @@ func (d *DatabaseSnapshot) MaxStreamPositionForNotificationData(ctx context.Cont } func (d *DatabaseSnapshot) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilterPart, excludeEventIDs) + return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) } func (d *DatabaseSnapshot) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { - return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.Tx, userID, membership) + return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership) } func (d *DatabaseSnapshot) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { - return d.Memberships.SelectMembershipCount(ctx, d.Tx, roomID, membership, pos) + return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) } func (d *DatabaseSnapshot) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, d.Tx, roomID, userID, memberships) + return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) } func (d *DatabaseSnapshot) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, d.Tx, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) } func (d *DatabaseSnapshot) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { - return d.Topology.SelectPositionInTopology(ctx, d.Tx, eventID) + return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) } func (d *DatabaseSnapshot) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { - return d.Invites.SelectInviteEventsInRange(ctx, d.Tx, targetUserID, r) + return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r) } func (d *DatabaseSnapshot) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { - return d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, deviceID, r) + return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r) } func (d *DatabaseSnapshot) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { - return d.Receipts.SelectRoomReceiptsAfter(ctx, d.Tx, roomIDs, streamPos) + return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) } // Events lookups a list of event by their event ID. @@ -169,7 +183,7 @@ func (d *DatabaseSnapshot) RoomReceiptsAfter(ctx context.Context, roomIDs []stri // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. func (d *DatabaseSnapshot) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.Tx, eventIDs, nil, false) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false) if err != nil { return nil, err } @@ -191,31 +205,31 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse } func (d *DatabaseSnapshot) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsers(ctx, d.Tx) + return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn) } func (d *DatabaseSnapshot) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.Tx, roomIDs) + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs) } func (d *DatabaseSnapshot) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { - return d.Peeks.SelectPeekingDevices(ctx, d.Tx) + return d.Peeks.SelectPeekingDevices(ctx, d.txn) } func (d *DatabaseSnapshot) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { - return d.CurrentRoomState.SelectSharedUsers(ctx, d.Tx, userID, otherUserIDs) + return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs) } func (d *DatabaseSnapshot) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectStateEvent(ctx, d.Tx, roomID, evType, stateKey) + return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey) } func (d *DatabaseSnapshot) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilter, nil) + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) return } @@ -300,7 +314,7 @@ func (d *DatabaseSnapshot) GetAccountDataInRange( ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter, ) (map[string][]string, types.StreamPosition, error) { - return d.AccountData.SelectAccountDataInRange(ctx, d.Tx, userID, r, accountDataFilterPart) + return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart) } // UpsertAccountData keeps track of new or updated account data, by saving the type @@ -493,27 +507,27 @@ func (d *DatabaseSnapshot) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, d.Tx, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, + ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ) if err != nil { return } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, d.Tx, eIDs, filter, true) + events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true) return } func (d *DatabaseSnapshot) BackwardExtremitiesForRoom( ctx context.Context, roomID string, ) (backwardExtremities map[string][]string, err error) { - return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.Tx, roomID) + return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } func (d *DatabaseSnapshot) MaxTopologicalPosition( ctx context.Context, roomID string, ) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.Tx, roomID) + depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) if err != nil { return types.TopologyToken{}, err } @@ -523,7 +537,7 @@ func (d *DatabaseSnapshot) MaxTopologicalPosition( func (d *DatabaseSnapshot) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { - depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.Tx, eventID) + depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) if err != nil { return types.TopologyToken{}, err } @@ -533,12 +547,12 @@ func (d *DatabaseSnapshot) EventPositionInTopology( func (d *DatabaseSnapshot) StreamToTopologicalPosition( ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (types.TopologyToken, error) { - topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.Tx, roomID, streamPos, backwardOrdering) + topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering) switch { case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.Tx, roomID) + topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) if err != nil { return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) } @@ -553,7 +567,7 @@ func (d *DatabaseSnapshot) StreamToTopologicalPosition( func (d *DatabaseSnapshot) GetFilter( ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { - return d.Filter.SelectFilter(ctx, d.Tx, target, localpart, filterID) + return d.Filter.SelectFilter(ctx, d.txn, target, localpart, filterID) } func (d *Database) PutFilter( @@ -600,7 +614,7 @@ func (d *DatabaseSnapshot) GetBackwardTopologyPos( if len(events) == 0 { return zeroToken, nil } - pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.Tx, events[0].EventID()) + pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID()) if err != nil { return zeroToken, err } @@ -721,7 +735,7 @@ func (d *DatabaseSnapshot) GetStateDeltas( // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.Tx, userID) + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -739,14 +753,14 @@ func (d *DatabaseSnapshot) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.Tx, r, stateFilter, allRoomIDs) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil } return nil, nil, err } - state, err := d.fetchStateEvents(ctx, d.Tx, stateNeeded, eventMap) + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -756,7 +770,7 @@ func (d *DatabaseSnapshot) GetStateDeltas( // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten - peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, device.ID, r) + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -839,7 +853,7 @@ func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync( ) ([]types.StateDelta, []string, error) { // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.Tx, userID) + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -859,7 +873,7 @@ func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync( // Use a reasonable initial capacity deltas := make(map[string]types.StateDelta) - peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.Tx, userID, device.ID, r) + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -883,14 +897,14 @@ func (d *DatabaseSnapshot) GetStateDeltasForFullStateSync( } // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.Tx, r, stateFilter, allRoomIDs) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil } return nil, nil, err } - state, err := d.fetchStateEvents(ctx, d.Tx, stateNeeded, eventMap) + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -946,7 +960,7 @@ func (d *DatabaseSnapshot) currentStateStreamEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]types.StreamEvent, error) { - allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.Tx, roomID, stateFilter, nil) + allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) if err != nil { return nil, err } @@ -984,7 +998,7 @@ func (d *DatabaseSnapshot) SendToDeviceUpdatesForSync( from, to types.StreamPosition, ) (types.StreamPosition, []types.SendToDeviceEvent, error) { // First of all, get our send-to-device updates for this user. - lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.Tx, userID, deviceID, from, to) + lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to) if err != nil { return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) } @@ -1032,7 +1046,7 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId } func (d *DatabaseSnapshot) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { - _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.Tx, roomIDs, streamPos) + _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) return receipts, err } @@ -1052,7 +1066,7 @@ func (d *DatabaseSnapshot) GetUserUnreadNotificationCountsForRooms(ctx context.C } roomIDs = append(roomIDs, roomID) } - return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.Tx, userID, roomIDs) + return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index a84e2bd16..6edb5ace3 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -49,6 +49,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) return &d, nil } +func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseSnapshot, error) { + return &shared.DatabaseSnapshot{ + Database: &d.Database, + // not setting a transaction because SQLite doesn't support it + }, nil +} + func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) { if err = d.streamID.Prepare(d.db); err != nil { return err