diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 963d31713..a3011b500 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -109,8 +109,7 @@ type txnReq struct { federation txnFederationClient // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. - haveEvents map[string]*gomatrixserverlib.HeaderedEvent - haveEventsMutex sync.RWMutex + haveEvents map[string]*gomatrixserverlib.HeaderedEvent // new events which the roomserver does not know about newEvents map[string]bool newEventsMutex sync.RWMutex @@ -266,9 +265,7 @@ func (e missingPrevEventsError) Error() string { } func (t *txnReq) haveEventIDs() map[string]bool { - t.haveEventsMutex.RLock() t.newEventsMutex.RLock() - defer t.haveEventsMutex.RUnlock() defer t.newEventsMutex.RUnlock() result := make(map[string]bool, len(t.haveEvents)) for eventID := range t.haveEvents { @@ -710,9 +707,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix default: return nil, false, fmt.Errorf("t.lookupEvent: %w", err) } - t.haveEventsMutex.Lock() t.haveEvents[h.EventID()] = h - t.haveEventsMutex.Unlock() if h.StateKey() != nil { addedToState := false for i := range respState.StateEvents { @@ -741,7 +736,6 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) return nil } - t.haveEventsMutex.RLock() for i, ev := range res.StateEvents { t.haveEvents[ev.EventID()] = res.StateEvents[i] } @@ -757,7 +751,6 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event } } } - t.haveEventsMutex.RUnlock() // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't // have stored the event. var missingEventList []string @@ -772,13 +765,11 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } - t.haveEventsMutex.Lock() for i := range queryRes.Events { evID := queryRes.Events[i].EventID() t.haveEvents[evID] = queryRes.Events[i] authEvents = append(authEvents, queryRes.Events[i].Unwrap()) } - t.haveEventsMutex.Unlock() evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents) return &gomatrixserverlib.RespState{ @@ -972,7 +963,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) missing := make(map[string]bool) var missingEventList []string - t.haveEventsMutex.RLock() for _, sid := range wantIDs { if _, ok := t.haveEvents[sid]; !ok { if !missing[sid] { @@ -981,7 +971,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } } } - t.haveEventsMutex.RUnlock() // fetch as many as we can from the roomserver queryReq := api.QueryEventsByIDRequest{ @@ -991,7 +980,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } - t.haveEventsMutex.RLock() for i := range queryRes.Events { evID := queryRes.Events[i].EventID() t.haveEvents[evID] = queryRes.Events[i] @@ -999,7 +987,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even delete(missing, evID) } } - t.haveEventsMutex.RUnlock() concurrentRequests := 8 missingCount := len(missing) @@ -1050,6 +1037,11 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even var fetchgroup sync.WaitGroup fetchgroup.Add(concurrentRequests) + // This is the only place where we'll write to t.haveEvents from + // multiple goroutines, and everywhere else is blocked on this + // synchronous function anyway. + var haveEventsMutex sync.Mutex + // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { var h *gomatrixserverlib.HeaderedEvent @@ -1066,9 +1058,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even }).Info("Failed to fetch missing event") return } - t.haveEventsMutex.Lock() + haveEventsMutex.Lock() t.haveEvents[h.EventID()] = h - t.haveEventsMutex.Unlock() + haveEventsMutex.Unlock() } // Create the worker. @@ -1095,9 +1087,6 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat // create a RespState response using the response to /state_ids as a guide respState := gomatrixserverlib.RespState{} - t.haveEventsMutex.RLock() - defer t.haveEventsMutex.RUnlock() - for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok {