diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index d5e44f72c..989224a49 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -232,7 +232,8 @@ type txnReq struct { // something that can tell us about which servers are in a room right now servers federationAPI.ServersInRoomProvider // a list of events from the auth and prev events which we already had - hadEvents map[string]bool + hadEvents map[string]bool + hadEventsMutex sync.Mutex // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -594,12 +595,14 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e // Prepare a map of all the events we already had before this point, so // that we don't send them to the roomserver again. + t.hadEventsMutex.Lock() for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { t.hadEvents[eventID] = true } for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { t.hadEvents[eventID] = false } + t.hadEventsMutex.Unlock() if len(stateResp.MissingAuthEventIDs) > 0 { t.work = MetricsWorkMissingAuthEvents @@ -673,7 +676,9 @@ withNextEvent: ); err != nil { return fmt.Errorf("api.SendEvents: %w", err) } + t.hadEventsMutex.Lock() t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now + t.hadEventsMutex.Unlock() t.cacheAndReturn(ev.Headered(stateResp.RoomVersion)) delete(missingAuthEvents, missingAuthEventID) continue withNextEvent @@ -801,14 +806,23 @@ func (t *txnReq) processEventWithMissingState( // First of all, send the backward extremity into the roomserver with the // newly resolved state. This marks the "oldest" point in the backfill and - // sets the baseline state for any new events after this. + // sets the baseline state for any new events after this. We'll make a + // copy of the hadEvents map so that it can be taken downstream without + // worrying about concurrent map reads/writes, since t.hadEvents is meant + // to be protected by a mutex. + hadEvents := map[string]bool{} + t.hadEventsMutex.Lock() + for k, v := range t.hadEvents { + hadEvents[k] = v + } + t.hadEventsMutex.Unlock() err = api.SendEventWithState( context.Background(), t.rsAPI, api.KindOld, resolvedState, backwardsExtremity.Headered(roomVersion), - t.hadEvents, + hadEvents, ) if err != nil { return fmt.Errorf("api.SendEventWithState: %w", err) @@ -904,7 +918,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. stateEvents[i] = t.cacheAndReturn(ev) + t.hadEventsMutex.Lock() t.hadEvents[ev.EventID()] = true + t.hadEventsMutex.Unlock() } // we should never access res.StateEvents again so we delete it here to make GC faster res.StateEvents = nil @@ -939,7 +955,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event } for i, ev := range queryRes.Events { authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + t.hadEventsMutex.Lock() t.hadEvents[ev.EventID()] = true + t.hadEventsMutex.Unlock() } queryRes.Events = nil } @@ -1016,7 +1034,9 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even latestEvents := make([]string, len(res.LatestEvents)) for i, ev := range res.LatestEvents { latestEvents[i] = res.LatestEvents[i].EventID + t.hadEventsMutex.Lock() t.hadEvents[ev.EventID] = true + t.hadEventsMutex.Unlock() } var missingResp *gomatrixserverlib.RespMissingEvents @@ -1152,7 +1172,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } for i, ev := range queryRes.Events { queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + t.hadEventsMutex.Lock() t.hadEvents[ev.EventID()] = true + t.hadEventsMutex.Unlock() evID := queryRes.Events[i].EventID() if missing[evID] { delete(missing, evID)