diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 104d2e73e..0a944f517 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -109,9 +109,11 @@ type txnReq struct { federation txnFederationClient // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. - haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEventsMutex sync.RWMutex // new events which the roomserver does not know about - newEvents map[string]bool + newEvents map[string]bool + newEventsMutex sync.RWMutex } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -264,6 +266,10 @@ func (e missingPrevEventsError) Error() string { } func (t *txnReq) haveEventIDs() map[string]bool { + t.haveEventsMutex.RLock() + t.newEventsMutex.RLock() + defer t.haveEventsMutex.RUnlock() + defer t.newEventsMutex.RUnlock() result := make(map[string]bool, len(t.haveEvents)) for eventID := range t.haveEvents { if t.newEvents[eventID] { @@ -704,7 +710,9 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix default: return nil, false, fmt.Errorf("t.lookupEvent: %w", err) } + t.haveEventsMutex.Lock() t.haveEvents[h.EventID()] = h + t.haveEventsMutex.Unlock() if h.StateKey() != nil { addedToState := false for i := range respState.StateEvents { @@ -733,6 +741,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) return nil } + t.haveEventsMutex.RLock() for i, ev := range res.StateEvents { t.haveEvents[ev.EventID()] = res.StateEvents[i] } @@ -748,6 +757,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event } } } + t.haveEventsMutex.RUnlock() // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't // have stored the event. var missingEventList []string @@ -762,11 +772,13 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } + t.haveEventsMutex.Lock() for i := range queryRes.Events { evID := queryRes.Events[i].EventID() t.haveEvents[evID] = queryRes.Events[i] authEvents = append(authEvents, queryRes.Events[i].Unwrap()) } + t.haveEventsMutex.Unlock() evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents) return &gomatrixserverlib.RespState{ @@ -960,6 +972,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) missing := make(map[string]bool) var missingEventList []string + t.haveEventsMutex.RLock() for _, sid := range wantIDs { if _, ok := t.haveEvents[sid]; !ok { if !missing[sid] { @@ -968,6 +981,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } } } + t.haveEventsMutex.RUnlock() // fetch as many as we can from the roomserver queryReq := api.QueryEventsByIDRequest{ @@ -977,6 +991,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } + t.haveEventsMutex.RLock() for i := range queryRes.Events { evID := queryRes.Events[i].EventID() t.haveEvents[evID] = queryRes.Events[i] @@ -984,6 +999,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even delete(missing, evID) } } + t.haveEventsMutex.RUnlock() concurrentRequests := 8 missingCount := len(missing) @@ -1034,11 +1050,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even var fetchgroup sync.WaitGroup fetchgroup.Add(concurrentRequests) - // This is the only place where we'll write to t.haveEvents from - // multiple goroutines, and everywhere else is blocked on this - // synchronous function anyway. - var haveEventsMutex sync.Mutex - // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { var h *gomatrixserverlib.HeaderedEvent @@ -1055,9 +1066,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even }).Info("Failed to fetch missing event") return } - haveEventsMutex.Lock() + t.haveEventsMutex.Lock() t.haveEvents[h.EventID()] = h - haveEventsMutex.Unlock() + t.haveEventsMutex.Unlock() } // Create the worker. @@ -1084,6 +1095,9 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat // create a RespState response using the response to /state_ids as a guide respState := gomatrixserverlib.RespState{} + t.haveEventsMutex.RLock() + t.haveEventsMutex.RUnlock() + for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok { @@ -1144,6 +1158,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. return nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) + t.newEventsMutex.Lock() t.newEvents[h.EventID()] = true + t.newEventsMutex.Unlock() return h, nil }