More tweaks

This commit is contained in:
Neil Alexander 2021-01-19 17:07:59 +00:00
parent e3c625b5ea
commit 142299c2b7
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 16 additions and 10 deletions

View file

@ -180,7 +180,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
stateFilter *gomatrixserverlib.StateFilter,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, selectCurrentStateSQL,
s.db, txn, selectCurrentStateSQL,
[]interface{}{},
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
@ -190,7 +190,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}

View file

@ -16,7 +16,7 @@ import (
// and it's easier just to have the caller extract the relevant
// parts.
func prepareWithFilters(
db *sql.DB, query string, params []interface{},
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string,
limit int, order string,
) (*sql.Stmt, []interface{}, error) {
@ -51,7 +51,13 @@ func prepareWithFilters(
query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
stmt, err := db.Prepare(query)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = db.Prepare(query)
}
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}

View file

@ -144,7 +144,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, selectStateInRangeSQL,
s.db, txn, selectStateInRangeSQL,
[]interface{}{
r.Low(), r.High(),
},
@ -156,7 +156,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, nil, err
}
@ -317,7 +317,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}
stmt, params, err := prepareWithFilters(
s.db, query,
s.db, txn, query,
[]interface{}{
roomID, r.Low(), r.High(),
},
@ -329,7 +329,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, false, err
}
@ -365,7 +365,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
stmt, params, err := prepareWithFilters(
s.db, selectEarlyEventsSQL,
s.db, txn, selectEarlyEventsSQL,
[]interface{}{
roomID, r.Low(), r.High(),
},
@ -376,7 +376,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}