diff --git a/src/github.com/matrix-org/dendrite/common/sql.go b/src/github.com/matrix-org/dendrite/common/sql.go index 4abe7410e..c2fb753fc 100644 --- a/src/github.com/matrix-org/dendrite/common/sql.go +++ b/src/github.com/matrix-org/dendrite/common/sql.go @@ -55,3 +55,14 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { succeeded = true return } + +// TxStmt wraps an SQL stmt inside an optional transaction. +// If the transaction is nil then it returns the original statement that will +// run outside of a transaction. +// Otherwise returns a copy of the statement that will run inside the transaction. +func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt { + if transaction != nil { + statement = transaction.Stmt(statement) + } + return statement +} diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go index 7ba1b0b07..fffcc7f3f 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -79,18 +80,18 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { func (s *joinedHostsStatements) insertJoinedHosts( txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - _, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) + _, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) return err } func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error { - _, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) + _, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) return err } func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID) + rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go index daac7ddf4..bcc0bb1df 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go @@ -16,6 +16,8 @@ package storage import ( "database/sql" + + "github.com/matrix-org/dendrite/common" ) const roomSchema = ` @@ -65,7 +67,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { - _, err := txn.Stmt(s.insertRoomStmt).Exec(roomID) + _, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID) return err } @@ -74,7 +76,7 @@ func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { // exists by calling insertRoom first. func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) { var lastEventID string - err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) + err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) if err != nil { return "", err } @@ -84,6 +86,6 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error { - _, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID) + _, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go index b4dae8f25..b06f5b2a5 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -92,21 +93,13 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := s.insertEventStateKeyNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := s.selectEventStateKeyNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } @@ -131,11 +124,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) { var eventStateKey string - stmt := s.selectEventStateKeyStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(eventStateKeyNID).Scan(&eventStateKey) + err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey) return eventStateKey, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index b6db15c82..2d2b85625 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -253,22 +254,22 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ } func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) { - err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) + err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) return } func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error { - _, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID)) + _, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID)) return err } func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) { - err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) + err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) return } func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { - rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) + rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go index 9e0860b42..8bae2b781 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -94,7 +95,7 @@ func (s *inviteStatements) insertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - result, err := txn.Stmt(s.insertInviteEventStmt).Exec( + result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec( inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) if err != nil { @@ -110,7 +111,7 @@ func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) updateInviteRetired( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { - rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) + rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go index 52051af59..6edc7a528 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -115,14 +116,14 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) error { - _, err := txn.Stmt(s.insertMembershipStmt).Exec(roomNID, targetUserNID) + _, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID) return err } func (s *membershipStatements) selectMembershipForUpdate( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership membershipState, err error) { - err = txn.Stmt(s.selectMembershipForUpdateStmt).QueryRow( + err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow( roomNID, targetUserNID, ).Scan(&membership) return @@ -179,7 +180,7 @@ func (s *membershipStatements) updateMembership( senderUserNID types.EventStateKeyNID, membership membershipState, eventNID types.EventNID, ) error { - _, err := txn.Stmt(s.updateMembershipStmt).Exec( + _, err := common.TxStmt(txn, s.updateMembershipStmt).Exec( roomNID, targetUserNID, senderUserNID, membership, eventNID, ) return err diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go index 71795d488..9fcf1cb5c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -73,7 +74,7 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) { } func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error { - _, err := txn.Stmt(s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) + _, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) return err } @@ -81,5 +82,5 @@ func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEvent // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error { var ok int64 - return txn.Stmt(s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) + return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index 24744fdff..4ba329f39 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -82,21 +83,13 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { var roomNID int64 - stmt := s.insertRoomNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(roomID).Scan(&roomNID) + err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { var roomNID int64 - stmt := s.selectRoomNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(roomID).Scan(&roomNID) + err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } @@ -120,7 +113,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty var nids pq.Int64Array var lastEventSentNID int64 var stateSnapshotNID int64 - err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err } @@ -135,7 +128,7 @@ func (s *roomStatements) updateLatestEventNIDs( txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec( + _, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec( roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), ) return err diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go index 9958e0d15..10933e965 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,7 +137,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { - rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership) + rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { - rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID) + rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID) if err != nil { return nil, err } @@ -165,21 +166,21 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri } func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error { - _, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID) + _, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID) return err } func (s *currentRoomStateStatements) upsertRoomState( txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64, ) error { - _, err := txn.Stmt(s.upsertRoomStateStmt).Exec( + _, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec( event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt, ) return err } func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - rows, err := txn.Stmt(s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) + rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go index f3c46298a..93774d1f1 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go @@ -19,6 +19,7 @@ import ( log "github.com/Sirupsen/logrus" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -105,7 +106,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) selectStateInRange( txn *sql.Tx, oldPos, newPos types.StreamPosition, ) (map[string]map[string]bool, map[string]streamEvent, error) { - rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos) + rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos) if err != nil { return nil, nil, err } @@ -167,12 +168,8 @@ func (s *outputRoomEventsStatements) selectStateInRange( // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) { - stmt := s.selectMaxIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } var nullableID sql.NullInt64 - err = stmt.QueryRow().Scan(&nullableID) + err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } @@ -182,7 +179,7 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { - err = txn.Stmt(s.insertEventStmt).QueryRow( + err = common.TxStmt(txn, s.insertEventStmt).QueryRow( event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), ).Scan(&streamPos) return @@ -209,11 +206,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // from the database. func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - stmt := s.selectEventsStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - rows, err := stmt.Query(pq.StringArray(eventIDs)) + rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs)) if err != nil { return nil, err }