diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 641633b4e..fca767662 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1360,9 +1360,9 @@ func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { return fmt.Errorf("not supported on this database engine") } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) if err != nil { - return fmt.Errorf("failed to find room NID: %w", err) + return fmt.Errorf("failed to lock the room: %w", err) } if err := d.Purge.PurgeStateBlocks(ctx, txn, roomNID); err != nil { return fmt.Errorf("failed to purge state blocks: %w", err) diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..072746420 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -169,6 +169,12 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + return 0, fmt.Errorf("not supported on SQLite") +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 508ac3c2b..3e93f5dff 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -72,6 +72,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error