More tweaks

This commit is contained in:
Neil Alexander 2021-01-19 17:07:59 +00:00
parent e3c625b5ea
commit 142299c2b7
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 16 additions and 10 deletions

View file

@ -180,7 +180,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, selectCurrentStateSQL, s.db, txn, selectCurrentStateSQL,
[]interface{}{}, []interface{}{},
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
@ -190,7 +190,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,7 +16,7 @@ import (
// and it's easier just to have the caller extract the relevant // and it's easier just to have the caller extract the relevant
// parts. // parts.
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, senders, notsenders, types, nottypes []string,
limit int, order string, limit int, order string,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
@ -51,7 +51,13 @@ func prepareWithFilters(
query += fmt.Sprintf(" LIMIT $%d", offset+1) query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit) params = append(params, limit)
stmt, err := db.Prepare(query) var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = db.Prepare(query)
}
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
} }

View file

@ -144,7 +144,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, selectStateInRangeSQL, s.db, txn, selectStateInRangeSQL,
[]interface{}{ []interface{}{
r.Low(), r.High(), r.Low(), r.High(),
}, },
@ -156,7 +156,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -317,7 +317,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} }
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, query, s.db, txn, query,
[]interface{}{ []interface{}{
roomID, r.Low(), r.High(), roomID, r.Low(), r.High(),
}, },
@ -329,7 +329,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
@ -365,7 +365,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, selectEarlyEventsSQL, s.db, txn, selectEarlyEventsSQL,
[]interface{}{ []interface{}{
roomID, r.Low(), r.High(), roomID, r.Low(), r.High(),
}, },
@ -376,7 +376,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
} }
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) rows, err := stmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }