Make txn *sql.Tx arguments optional everywhere using a utility function (#191)

* Make txn *sql.Tx arguments optional everywhere using a utility function

* Clarify that if the txn is nil the stmt will run outside a transaction
This commit is contained in:
Mark Haines 2017-08-21 17:20:23 +01:00 committed by GitHub
parent 57b7097368
commit 808c2e09f6
11 changed files with 55 additions and 61 deletions

View file

@ -55,3 +55,14 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
succeeded = true succeeded = true
return return
} }
// TxStmt wraps an SQL stmt inside an optional transaction.
// If the transaction is nil then it returns the original statement that will
// run outside of a transaction.
// Otherwise returns a copy of the statement that will run inside the transaction.
func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
if transaction != nil {
statement = transaction.Stmt(statement)
}
return statement
}

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -79,18 +80,18 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
func (s *joinedHostsStatements) insertJoinedHosts( func (s *joinedHostsStatements) insertJoinedHosts(
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) _, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
return err return err
} }
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error { func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
_, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) _, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
return err return err
} }
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string, func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {
rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID) rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -16,6 +16,8 @@ package storage
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
) )
const roomSchema = ` const roomSchema = `
@ -65,7 +67,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
// insertRoom inserts the room if it didn't already exist. // insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string. // If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
_, err := txn.Stmt(s.insertRoomStmt).Exec(roomID) _, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
return err return err
} }
@ -74,7 +76,7 @@ func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
// exists by calling insertRoom first. // exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) { func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
var lastEventID string var lastEventID string
err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -84,6 +86,6 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should // updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction. // have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error { func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
_, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID) _, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
return err return err
} }

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -92,21 +93,13 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
stmt := s.insertEventStateKeyNIDStmt err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
stmt := s.selectEventStateKeyNIDStmt err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }
@ -131,11 +124,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) { func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
var eventStateKey string var eventStateKey string
stmt := s.selectEventStateKeyStmt err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err return eventStateKey, err
} }

View file

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -253,22 +254,22 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ
} }
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) { func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
return return
} }
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error { func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
_, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID)) _, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
return err return err
} }
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) { func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
return return
} }
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -17,6 +17,7 @@ package storage
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -94,7 +95,7 @@ func (s *inviteStatements) insertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
result, err := txn.Stmt(s.insertInviteEventStmt).Exec( result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
) )
if err != nil { if err != nil {
@ -110,7 +111,7 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) updateInviteRetired(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -17,6 +17,7 @@ package storage
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -115,14 +116,14 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership( func (s *membershipStatements) insertMembership(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error { ) error {
_, err := txn.Stmt(s.insertMembershipStmt).Exec(roomNID, targetUserNID) _, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
return err return err
} }
func (s *membershipStatements) selectMembershipForUpdate( func (s *membershipStatements) selectMembershipForUpdate(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) { ) (membership membershipState, err error) {
err = txn.Stmt(s.selectMembershipForUpdateStmt).QueryRow( err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
roomNID, targetUserNID, roomNID, targetUserNID,
).Scan(&membership) ).Scan(&membership)
return return
@ -179,7 +180,7 @@ func (s *membershipStatements) updateMembership(
senderUserNID types.EventStateKeyNID, membership membershipState, senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
_, err := txn.Stmt(s.updateMembershipStmt).Exec( _, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
roomNID, targetUserNID, senderUserNID, membership, eventNID, roomNID, targetUserNID, senderUserNID, membership, eventNID,
) )
return err return err

View file

@ -17,6 +17,7 @@ package storage
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -73,7 +74,7 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
} }
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error { func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
_, err := txn.Stmt(s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) _, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
return err return err
} }
@ -81,5 +82,5 @@ func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEvent
// Returns sql.ErrNoRows if the event reference doesn't exist. // Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error { func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
var ok int64 var ok int64
return txn.Stmt(s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
} }

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -82,21 +83,13 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
stmt := s.insertRoomNIDStmt err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
stmt := s.selectRoomNIDStmt err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
@ -120,7 +113,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
var nids pq.Int64Array var nids pq.Int64Array
var lastEventSentNID int64 var lastEventSentNID int64
var stateSnapshotNID int64 var stateSnapshotNID int64
err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -135,7 +128,7 @@ func (s *roomStatements) updateLatestEventNIDs(
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID, stateSnapshotNID types.StateSnapshotNID,
) error { ) error {
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec( _, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
) )
return err return err

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -136,7 +137,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership) rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,7 +156,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID) rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -165,21 +166,21 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri
} }
func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error { func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error {
_, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID) _, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID)
return err return err
} }
func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) upsertRoomState(
txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64, txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64,
) error { ) error {
_, err := txn.Stmt(s.upsertRoomStateStmt).Exec( _, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec(
event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt, event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt,
) )
return err return err
} }
func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
rows, err := txn.Stmt(s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,6 +19,7 @@ import (
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -105,7 +106,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
func (s *outputRoomEventsStatements) selectStateInRange( func (s *outputRoomEventsStatements) selectStateInRange(
txn *sql.Tx, oldPos, newPos types.StreamPosition, txn *sql.Tx, oldPos, newPos types.StreamPosition,
) (map[string]map[string]bool, map[string]streamEvent, error) { ) (map[string]map[string]bool, map[string]streamEvent, error) {
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos) rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -167,12 +168,8 @@ func (s *outputRoomEventsStatements) selectStateInRange(
// then this function should only ever be used at startup, as it will race with inserting events if it is // then this function should only ever be used at startup, as it will race with inserting events if it is
// done afterwards. If there are no inserted events, 0 is returned. // done afterwards. If there are no inserted events, 0 is returned.
func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) { func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) {
stmt := s.selectMaxIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
var nullableID sql.NullInt64 var nullableID sql.NullInt64
err = stmt.QueryRow().Scan(&nullableID) err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID)
if nullableID.Valid { if nullableID.Valid {
id = nullableID.Int64 id = nullableID.Int64
} }
@ -182,7 +179,7 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
// of the inserted event. // of the inserted event.
func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) {
err = txn.Stmt(s.insertEventStmt).QueryRow( err = common.TxStmt(txn, s.insertEventStmt).QueryRow(
event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState),
).Scan(&streamPos) ).Scan(&streamPos)
return return
@ -209,11 +206,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
// from the database. // from the database.
func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
stmt := s.selectEventsStmt rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs))
if txn != nil {
stmt = txn.Stmt(stmt)
}
rows, err := stmt.Query(pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }