diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a06b4b696..3169b659b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -468,27 +468,38 @@ func (d *Database) StoreEvent( if !isRejected { // ignore rejected redaction events redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) } - - var roomInfo *types.RoomInfo - var updater *LatestEventsUpdater - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return fmt.Errorf("d.RoomInfo: %w", err) - } - updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return fmt.Errorf("NewLatestEventsUpdater: %w", err) - } - if err = updater.StorePreviousEvents(eventNID, event.PrevEvents()); err != nil { - return fmt.Errorf("updater.StorePreviousEvents: %w", err) - } - succeeded := false - return sqlutil.EndTransaction(updater, &succeeded) + return nil }) if err != nil { return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } + // We should attempt to update the previous events table with any + // references that this new event makes. We do this using a latest + // events updater because it somewhat works as a mutex, ensuring + // that there's a row-level lock on the latest room events (well, + // on Postgres at least). + var roomInfo *types.RoomInfo + var updater *LatestEventsUpdater + if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) + } + if roomInfo == nil && len(prevEvents) > 0 { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) + if err != nil { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) + } + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + succeeded := false + err = sqlutil.EndTransaction(updater, &succeeded) + } + return roomNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, StateEntry: types.StateEntry{ @@ -498,7 +509,7 @@ func (d *Database) StoreEvent( }, EventNID: eventNID, }, - }, redactionEvent, redactedEventID, nil + }, redactionEvent, redactedEventID, err } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {