From 142299c2b7fae91e4d5fec0e7013e2a10147b8a5 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 19 Jan 2021 17:07:59 +0000 Subject: [PATCH] More tweaks --- syncapi/storage/sqlite3/current_room_state_table.go | 4 ++-- syncapi/storage/sqlite3/filtering.go | 10 ++++++++-- syncapi/storage/sqlite3/output_room_events_table.go | 12 ++++++------ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 21202d2ac..517b4ec9f 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -180,7 +180,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter *gomatrixserverlib.StateFilter, ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt, params, err := prepareWithFilters( - s.db, selectCurrentStateSQL, + s.db, txn, selectCurrentStateSQL, []interface{}{}, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, @@ -190,7 +190,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } - rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index d8465388e..0f6bd681f 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -16,7 +16,7 @@ import ( // and it's easier just to have the caller extract the relevant // parts. func prepareWithFilters( - db *sql.DB, query string, params []interface{}, + db *sql.DB, txn *sql.Tx, query string, params []interface{}, senders, notsenders, types, nottypes []string, limit int, order string, ) (*sql.Stmt, []interface{}, error) { @@ -51,7 +51,13 @@ func prepareWithFilters( query += fmt.Sprintf(" LIMIT $%d", offset+1) params = append(params, limit) - stmt, err := db.Prepare(query) + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = db.Prepare(query) + } if err != nil { return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index cc48176ca..4debc53f6 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -144,7 +144,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt, params, err := prepareWithFilters( - s.db, selectStateInRangeSQL, + s.db, txn, selectStateInRangeSQL, []interface{}{ r.Low(), r.High(), }, @@ -156,7 +156,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) } - rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, nil, err } @@ -317,7 +317,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( } stmt, params, err := prepareWithFilters( - s.db, query, + s.db, txn, query, []interface{}{ roomID, r.Low(), r.High(), }, @@ -329,7 +329,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) } - rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, false, err } @@ -365,7 +365,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { stmt, params, err := prepareWithFilters( - s.db, selectEarlyEventsSQL, + s.db, txn, selectEarlyEventsSQL, []interface{}{ roomID, r.Low(), r.High(), }, @@ -376,7 +376,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } - rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err }