From 4440ab0a8aa3b0e814f10ed0610e9f04838a42ef Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 28 Sep 2022 15:17:07 +0100 Subject: [PATCH] Fix some bugs --- syncapi/storage/shared/syncserver.go | 16 ++++++++-------- .../storage/sqlite3/current_room_state_table.go | 8 +++++++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index c48d1d651..0c6b7e025 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -94,7 +94,7 @@ func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseSnapshot, } func (d *DatabaseSnapshot) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { - id, err := d.OutputEvents.SelectMaxEventID(ctx, nil) + id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) } @@ -102,7 +102,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForPDUs(ctx context.Context) (types. } func (d *DatabaseSnapshot) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Receipts.SelectMaxReceiptID(ctx, nil) + id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) } @@ -110,7 +110,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForReceipts(ctx context.Context) (ty } func (d *DatabaseSnapshot) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Invites.SelectMaxInviteID(ctx, nil) + id, err := d.Invites.SelectMaxInviteID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) } @@ -118,7 +118,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForInvites(ctx context.Context) (typ } func (d *DatabaseSnapshot) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { - id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil) + id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) } @@ -126,7 +126,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForSendToDeviceMessages(ctx context. } func (d *DatabaseSnapshot) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) + id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) } @@ -134,7 +134,7 @@ func (d *DatabaseSnapshot) MaxStreamPositionForAccountData(ctx context.Context) } func (d *DatabaseSnapshot) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.NotificationData.SelectMaxID(ctx, nil) + id, err := d.NotificationData.SelectMaxID(ctx, d.txn) if err != nil { return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) } @@ -1112,8 +1112,8 @@ func (d *DatabaseSnapshot) PresenceAfter(ctx context.Context, after types.Stream return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) } -func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return d.Presence.GetMaxPresenceID(ctx, nil) +func (d *DatabaseSnapshot) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, d.txn) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ba6d8126c..c4019fed2 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -367,7 +367,13 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( for start < len(eventIDs) { n := minOfInts(len(eventIDs)-start, 999) query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1) - rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + var rows *sql.Rows + var err error + if txn == nil { + rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } else { + rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } if err != nil { return nil, err }