diff --git a/roomserver/input/authevents.go b/roomserver/input/authevents.go index 74be2ed33..9c6ba9429 100644 --- a/roomserver/input/authevents.go +++ b/roomserver/input/authevents.go @@ -16,6 +16,7 @@ package input import ( "context" + "fmt" "sort" "github.com/matrix-org/dendrite/roomserver/types" @@ -35,16 +36,19 @@ func checkAuthEvents( if err != nil { return nil, err } + fmt.Println("authStateEntries:", authStateEntries) // TODO: check for duplicate state keys here. // Work out which of the state events we actually need. stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) + fmt.Println("stateNeeded:", stateNeeded) // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) if err != nil { return nil, err } + fmt.Println("authEvents:", authEvents) // Check if the event is allowed. if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 0ea6ee8ec..ffca7ff2e 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -61,9 +62,9 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { } func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + _, err := common.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } @@ -73,9 +74,9 @@ type eventJSONPair struct { } func (s *eventJSONStatements) bulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]eventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + rows, err := common.TxStmt(txn, s.bulkSelectEventJSONStmt).QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { fmt.Println("bulkSelectEventJSON s.bulkSelectEventJSONStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 0d4eda82a..60d1cb30e 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -118,10 +118,10 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( } func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( - ctx, pq.StringArray(eventStateKeys), + rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt).QueryContext( + ctx, sqliteInStr(pq.StringArray(eventStateKeys)), ) if err != nil { fmt.Println("bulkSelectEventStateKeyNID s.bulkSelectEventStateKeyNIDStmt.QueryContext:", err) @@ -142,13 +142,13 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) bulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + rows, err := common.TxStmt(txn, s.bulkSelectEventStateKeyStmt).QueryContext(ctx, nIDs) if err != nil { fmt.Println("bulkSelectEventStateKey s.bulkSelectEventStateKeyStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 6b9e35611..4b00f3fdf 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -94,28 +95,32 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { } func (s *eventTypeStatements) insertEventTypeNID( - ctx context.Context, eventType string, + ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 var err error - if _, err = s.insertEventTypeNIDStmt.ExecContext(ctx, eventType); err == nil { - err = s.insertEventTypeNIDResultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) + insertStmt := common.TxStmt(tx, s.insertEventTypeNIDStmt) + resultStmt := common.TxStmt(tx, s.insertEventTypeNIDResultStmt) + if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { + err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) } return types.EventTypeNID(eventTypeNID), err } func (s *eventTypeStatements) selectEventTypeNID( - ctx context.Context, eventType string, + ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + selectStmt := common.TxStmt(tx, s.selectEventTypeNIDStmt) + err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } func (s *eventTypeStatements) bulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, tx *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + selectStmt := common.TxStmt(tx, s.bulkSelectEventTypeNIDStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventTypes))) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index e02b62eeb..6b03a42f1 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -138,6 +138,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) insertEvent( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -149,30 +150,34 @@ func (s *eventStatements) insertEvent( var eventNID int64 var stateNID int64 var err error - if _, err = s.insertEventStmt.ExecContext( + insertStmt := common.TxStmt(txn, s.insertEventStmt) + resultStmt := common.TxStmt(txn, s.insertEventResultStmt) + if _, err = insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ); err == nil { - err = s.insertEventResultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) + err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) } return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + selectStmt := common.TxStmt(txn, s.selectEventStmt) + err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) bulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + selectStmt := common.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) if err != nil { return nil, err } @@ -210,9 +215,10 @@ func (s *eventStatements) bulkSelectStateEventByID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) bulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) if err != nil { return nil, err } @@ -244,9 +250,10 @@ func (s *eventStatements) bulkSelectStateAtEventByID( } func (s *eventStatements) updateEventState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + updateStmt := common.TxStmt(txn, s.updateEventStateStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) if err != nil { fmt.Println("updateEventState s.updateEventStateStmt.ExecContext:", err) } @@ -256,8 +263,8 @@ func (s *eventStatements) updateEventState( func (s *eventStatements) selectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { - stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) - err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + selectStmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) if err != nil { fmt.Println("selectEventSentToOutput stmt.QueryRowContext:", err) @@ -266,8 +273,8 @@ func (s *eventStatements) selectEventSentToOutput( } func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := stmt.ExecContext(ctx, int64(eventNID)) + updateStmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) if err != nil { fmt.Println("updateEventSentToOutput stmt.QueryRowContext:", err) @@ -278,8 +285,8 @@ func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql. func (s *eventStatements) selectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { - stmt := common.TxStmt(txn, s.selectEventIDStmt) - err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + selectStmt := common.TxStmt(txn, s.selectEventIDStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) if err != nil { fmt.Println("selectEventID stmt.QueryRowContext:", err) } @@ -289,8 +296,8 @@ func (s *eventStatements) selectEventID( func (s *eventStatements) bulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { - stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) - rows, err := stmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) + selectStmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) if err != nil { fmt.Println("bulkSelectStateAtEventAndREference stmt.QueryContext:", err) return nil, err @@ -328,9 +335,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( } func (s *eventStatements) bulkSelectEventReference( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { - rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) + selectStmt := common.TxStmt(txn, s.bulkSelectEventReferenceStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) if err != nil { fmt.Println("bulkSelectEventReference s.bulkSelectEventReferenceStmt.QueryContext:", err) return nil, err @@ -352,8 +360,9 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) +func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + selectStmt := common.TxStmt(txn, s.bulkSelectEventIDStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))) if err != nil { fmt.Println("bulkSelectEventID s.bulkSelectEventIDStmt.QueryContext:", err) return nil, err @@ -378,8 +387,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + selectStmt := common.TxStmt(txn, s.bulkSelectEventNIDStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteInStr(pq.StringArray(eventIDs))) if err != nil { fmt.Println("bulkSelectEventNID s.bulkSelectEventNIDStmt.QueryContext:", err) return nil, err @@ -398,10 +408,10 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str return results, nil } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 - stmt := s.selectMaxEventDepthStmt - err := stmt.QueryRowContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))).Scan(&result) + selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt) + err := selectStmt.QueryRowContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))).Scan(&result) if err != nil { fmt.Println("selectMaxEventDepth stmt.QueryRowContext:", err) return 0, err diff --git a/roomserver/storage/sqlite3/list.go b/roomserver/storage/sqlite3/list.go index 4fe4e334b..f9874e838 100644 --- a/roomserver/storage/sqlite3/list.go +++ b/roomserver/storage/sqlite3/list.go @@ -16,3 +16,7 @@ func sqliteIn(a pq.Int64Array) string { } return strings.Join(b, ",") } + +func sqliteInStr(a pq.StringArray) string { + return "\"" + strings.Join(a, "\",\"") + "\"" +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 2c13b539d..9eb7222d8 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -96,8 +96,8 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { } func (s *membershipStatements) insertMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) @@ -108,8 +108,8 @@ func (s *membershipStatements) insertMembership( } func (s *membershipStatements) selectMembershipForUpdate( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership membershipState, err error) { stmt := common.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( @@ -122,10 +122,11 @@ func (s *membershipStatements) selectMembershipForUpdate( } func (s *membershipStatements) selectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership membershipState, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + selectStmt := common.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = selectStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) if err != nil { @@ -135,9 +136,11 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( } func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, ) (eventNIDs []types.EventNID, err error) { - rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) + selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { fmt.Println("selectMembershipsFromRoom s.selectMembershipsFromRoomStmt.QueryContext:", err) return @@ -154,10 +157,10 @@ func (s *membershipStatements) selectMembershipsFromRoom( return } func (s *membershipStatements) selectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership membershipState, ) (eventNIDs []types.EventNID, err error) { - stmt := s.selectMembershipsFromRoomAndMembershipStmt + stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { fmt.Println("selectMembershipsFromRoomAndMembership stmt.QueryContext:", err) @@ -176,8 +179,8 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( } func (s *membershipStatements) updateMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership membershipState, eventNID types.EventNID, ) error { diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 36d2588dd..b4a5735f3 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -19,6 +19,8 @@ import ( "context" "database/sql" "fmt" + + "github.com/matrix-org/dendrite/common" ) const roomAliasesSchema = ` @@ -74,9 +76,10 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { } func (s *roomAliasesStatements) insertRoomAlias( - ctx context.Context, alias string, roomID string, creatorUserID string, + ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ) (err error) { - _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + insertStmt := common.TxStmt(txn, s.insertRoomAliasStmt) + _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID) if err != nil { fmt.Println("insertRoomAlias s.insertRoomAliasStmt.ExecContent:", err) } @@ -84,9 +87,10 @@ func (s *roomAliasesStatements) insertRoomAlias( } func (s *roomAliasesStatements) selectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + selectStmt := common.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -94,10 +98,11 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias( } func (s *roomAliasesStatements) selectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + selectStmt := common.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := selectStmt.QueryContext(ctx, roomID) if err != nil { fmt.Println("selectAliasesFromRoomID s.selectAliasesFromRoomIDStmt.QueryContext:", err) return @@ -117,9 +122,10 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( } func (s *roomAliasesStatements) selectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + selectStmt := common.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } @@ -127,8 +133,9 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias( } func (s *roomAliasesStatements) deleteRoomAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (err error) { - _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) + deleteStmt := common.TxStmt(txn, s.deleteRoomAliasStmt) + _, err = deleteStmt.ExecContext(ctx, alias) return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 7fe2913a5..491084dd5 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -113,11 +113,11 @@ func (s *roomStatements) selectRoomNID( } func (s *roomStatements) selectLatestEventNIDs( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - stmt := s.selectLatestEventNIDsStmt + stmt := common.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { fmt.Println("selectLatestEventNIDs stmt.QueryRowContext:", err) diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index ff664fe9a..7a84f3136 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -93,12 +94,12 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { } func (s *stateBlockStatements) bulkInsertStateData( - ctx context.Context, + ctx context.Context, txn *sql.Tx, stateBlockNID types.StateBlockNID, entries []types.StateEntry, ) error { for _, entry := range entries { - _, err := s.insertStateDataStmt.ExecContext( + _, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext( ctx, int64(stateBlockNID), int64(entry.EventTypeNID), @@ -115,20 +116,23 @@ func (s *stateBlockStatements) bulkInsertStateData( func (s *stateBlockStatements) selectNextStateBlockNID( ctx context.Context, + txn *sql.Tx, ) (types.StateBlockNID, error) { var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) + selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt) + err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID) return types.StateBlockNID(stateBlockNID), err } func (s *stateBlockStatements) bulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) + selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) if err != nil { fmt.Println("bulkSelectStateBlockEntries s.bulkSelectStateBlockEntriesStmt.QueryContext:", err) return nil, err @@ -173,7 +177,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( } func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( - ctx context.Context, + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { @@ -182,7 +186,8 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( + selectStmt := common.TxStmt(txn, s.bulkSelectFilteredStateBlockEntriesStmt) + rows, err := selectStmt.QueryContext( ctx, stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index fde58f352..71a40f03f 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -69,14 +70,16 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { } func (s *stateSnapshotStatements) insertState( - ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - if _, err = s.insertStateStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil { - err = s.insertStateResultStmt.QueryRowContext(ctx).Scan(&stateNID) + insertStmt := common.TxStmt(txn, s.insertStateStmt) + resultStmt := common.TxStmt(txn, s.insertStateResultStmt) + if _, err = insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err == nil { + err = resultStmt.QueryRowContext(ctx).Scan(&stateNID) if err != nil { fmt.Println("insertState s.insertStateResultStmt.QueryRowContext:", err) } @@ -87,13 +90,14 @@ func (s *stateSnapshotStatements) insertState( } func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) + selectStmt := common.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) + rows, err := selectStmt.QueryContext(ctx, sqliteIn(pq.Int64Array(nids))) if err != nil { fmt.Println("bulkSelectStateBlockNIDs s.bulkSelectStateBlockNIDsStmt.QueryContext:", err) return nil, err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index b47adae5e..138d971f9 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -22,6 +22,7 @@ import ( "fmt" "net/url" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -52,9 +53,9 @@ func Open(dataSourceName string) (*Database, error) { if d.db, err = sql.Open("sqlite3", cs); err != nil { return nil, err } - //d.db.Exec("PRAGMA journal_mode=WAL;") + d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA parser_trace = true;") - d.db.SetMaxOpenConns(1) + //d.db.SetMaxOpenConns(1) if err = d.statements.prepare(d.db); err != nil { return nil, err } @@ -84,43 +85,59 @@ func (d *Database) StoreEvent( } } - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { + err = common.WithTransaction(d.db, func(tx *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, tx, event.RoomID()) + return err + }) + if err != nil { return 0, types.StateAtEvent{}, err } - if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { + err = common.WithTransaction(d.db, func(tx *sql.Tx) error { + eventTypeNID, err = d.assignEventTypeNID(ctx, tx, event.Type()) + return err + }) + if err != nil { return 0, types.StateAtEvent{}, err } - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { - return 0, types.StateAtEvent{}, err + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } } - } - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { return 0, types.StateAtEvent{}, err } @@ -138,9 +155,9 @@ func (d *Database) StoreEvent( func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, -) (types.RoomNID, error) { +) (roomNID types.RoomNID, err error) { // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) @@ -149,30 +166,30 @@ func (d *Database) assignRoomNID( roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) } } - return roomNID, err + return } func (d *Database) assignEventTypeNID( - ctx context.Context, eventType string, -) (types.EventTypeNID, error) { + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) } } - return eventTypeNID, err + return } func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, -) (types.EventStateKeyNID, error) { +) (eventStateKeyNID types.EventStateKeyNID, err error) { // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) @@ -181,61 +198,69 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) } } - return eventStateKeyNID, err + return } // StateEntriesForEventIDs implements input.EventDatabase func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) + return d.statements.bulkSelectStateEventByID(ctx, nil, eventIDs) } // EventTypeNIDs implements state.RoomStateDatabase func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) + return d.statements.bulkSelectEventTypeNID(ctx, nil, eventTypes) } // EventStateKeyNIDs implements state.RoomStateDatabase func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) + return d.statements.bulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) } // EventStateKeys implements query.RoomserverQueryAPIDatabase func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.statements.bulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } // EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) + return d.statements.bulkSelectEventNID(ctx, nil, eventIDs) } // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) - if err != nil { - return nil, err - } + var eventJSONs []eventJSONPair + var err error results := make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - result := &results[i] - result.EventNID = eventJSON.EventNID - // TODO: Use NewEventFromTrustedJSON for efficiency - result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + common.WithTransaction(d.db, func(txn *sql.Tx) error { + eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { - return nil, err + return nil } + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil + } + } + return nil + }) + if err != nil { + return []types.Event{}, err } return results, nil } @@ -246,62 +271,68 @@ func (d *Database) AddState( roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, -) (types.StateSnapshotNID, error) { - if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err +) (stateNID types.StateSnapshotNID, err error) { + common.WithTransaction(d.db, func(txn *sql.Tx) error { + if len(state) > 0 { + stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx, txn) + if err != nil { + return err + } + if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } - if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { - return 0, err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) + return nil + }) + if err != nil { + return 0, err } - - return d.statements.insertState(ctx, roomNID, stateBlockNIDs) + return } // SetState implements input.EventDatabase func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) + return d.statements.updateEventState(ctx, nil, eventNID, stateNID) } // StateAtEventIDs implements input.EventDatabase func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) + return d.statements.bulkSelectStateAtEventByID(ctx, nil, eventIDs) } // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) + return d.statements.bulkSelectStateBlockNIDs(ctx, nil, stateNIDs) } // StateEntries implements state.RoomStateDatabase func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) + return d.statements.bulkSelectStateBlockEntries(ctx, nil, stateBlockNIDs) } // SnapshotNIDFromEventID implements state.RoomStateDatabase func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err +) (stateNID types.StateSnapshotNID, err error) { + _, stateNID, err = d.statements.selectEvent(ctx, nil, eventID) + return } // EventIDs implements input.RoomEventDatabase func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) + return d.statements.bulkSelectEventID(ctx, nil, eventNIDs) } // GetLatestEventsForUpdate implements input.EventDatabase @@ -403,21 +434,25 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + // TODO: transaction was removed here - is this wise? + return u.d.statements.updateLatestEventNIDs(u.ctx, nil, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) } // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) + // TODO: transaction was removed here - is this wise? + return u.d.statements.selectEventSentToOutput(u.ctx, nil, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) + // TODO: transaction was removed here - is this wise? + return u.d.statements.updateEventSentToOutput(u.ctx, nil, eventNID) } func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) + // TODO: transaction was removed here - is this wise? + return u.d.membershipUpdaterTxn(u.ctx, nil, u.roomNID, targetUserNID) } // RoomNID implements query.RoomserverQueryAPIDB @@ -432,20 +467,24 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, e // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return } // GetInvitesForUser implements query.RoomserverQueryAPIDatabase @@ -459,29 +498,29 @@ func (d *Database) GetInvitesForUser( // SetRoomAlias implements alias.RoomserverAliasAPIDB func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) + return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) } // GetRoomIDForAlias implements alias.RoomserverAliasAPIDB func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, alias) + return d.statements.selectRoomIDFromAlias(ctx, nil, alias) } // GetAliasesForRoomID implements alias.RoomserverAliasAPIDB func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, roomID) + return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) } // GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, alias) + return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) } // RemoveRoomAlias implements alias.RoomserverAliasAPIDB func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, alias) + return d.statements.deleteRoomAlias(ctx, nil, alias) } // StateEntriesForTuples implements state.RoomStateDatabase @@ -491,7 +530,7 @@ func (d *Database) StateEntriesForTuples( stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, + ctx, nil, stateBlockNIDs, stateKeyTuples, ) } @@ -666,36 +705,46 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s func (d *Database) GetMembership( ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, ) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) + if err != nil { + return err + } - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } + membershipEventNID, _, err = + d.statements.selectMembershipFromRoomAndTarget( + ctx, txn, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return nil + } + if err != nil { + return err + } + stillInRoom = true + return nil + }) - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil + return } // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, - ) - } +) (eventNIDs []types.EventNID, err error) { + common.WithTransaction(d.db, func(txn *sql.Tx) error { + if joinOnly { + eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( + ctx, txn, roomNID, membershipStateJoin, + ) + return nil + } - return d.statements.selectMembershipsFromRoom(ctx, roomNID) + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + return nil + }) + return } // EventsFromIDs implements query.RoomserverQueryAPIEventDB