Protect txnReq.newEvents and txnReq.haveEvents with mutex

This commit is contained in:
Neil Alexander 2020-11-17 10:18:36 +00:00
parent 20a01bceb2
commit db01b82274
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -109,9 +109,11 @@ type txnReq struct {
federation txnFederationClient federation txnFederationClient
// local cache of events for auth checks, etc - this may include events // local cache of events for auth checks, etc - this may include events
// which the roomserver is unaware of. // which the roomserver is unaware of.
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
haveEventsMutex sync.RWMutex
// new events which the roomserver does not know about // new events which the roomserver does not know about
newEvents map[string]bool newEvents map[string]bool
newEventsMutex sync.RWMutex
} }
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
@ -264,6 +266,10 @@ func (e missingPrevEventsError) Error() string {
} }
func (t *txnReq) haveEventIDs() map[string]bool { func (t *txnReq) haveEventIDs() map[string]bool {
t.haveEventsMutex.RLock()
t.newEventsMutex.RLock()
defer t.haveEventsMutex.RUnlock()
defer t.newEventsMutex.RUnlock()
result := make(map[string]bool, len(t.haveEvents)) result := make(map[string]bool, len(t.haveEvents))
for eventID := range t.haveEvents { for eventID := range t.haveEvents {
if t.newEvents[eventID] { if t.newEvents[eventID] {
@ -704,7 +710,9 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
default: default:
return nil, false, fmt.Errorf("t.lookupEvent: %w", err) return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
} }
t.haveEventsMutex.Lock()
t.haveEvents[h.EventID()] = h t.haveEvents[h.EventID()] = h
t.haveEventsMutex.Unlock()
if h.StateKey() != nil { if h.StateKey() != nil {
addedToState := false addedToState := false
for i := range respState.StateEvents { for i := range respState.StateEvents {
@ -733,6 +741,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
return nil return nil
} }
t.haveEventsMutex.RLock()
for i, ev := range res.StateEvents { for i, ev := range res.StateEvents {
t.haveEvents[ev.EventID()] = res.StateEvents[i] t.haveEvents[ev.EventID()] = res.StateEvents[i]
} }
@ -748,6 +757,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
} }
} }
} }
t.haveEventsMutex.RUnlock()
// QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't
// have stored the event. // have stored the event.
var missingEventList []string var missingEventList []string
@ -762,11 +772,13 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
return nil return nil
} }
t.haveEventsMutex.Lock()
for i := range queryRes.Events { for i := range queryRes.Events {
evID := queryRes.Events[i].EventID() evID := queryRes.Events[i].EventID()
t.haveEvents[evID] = queryRes.Events[i] t.haveEvents[evID] = queryRes.Events[i]
authEvents = append(authEvents, queryRes.Events[i].Unwrap()) authEvents = append(authEvents, queryRes.Events[i].Unwrap())
} }
t.haveEventsMutex.Unlock()
evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents) evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents)
return &gomatrixserverlib.RespState{ return &gomatrixserverlib.RespState{
@ -960,6 +972,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...)
missing := make(map[string]bool) missing := make(map[string]bool)
var missingEventList []string var missingEventList []string
t.haveEventsMutex.RLock()
for _, sid := range wantIDs { for _, sid := range wantIDs {
if _, ok := t.haveEvents[sid]; !ok { if _, ok := t.haveEvents[sid]; !ok {
if !missing[sid] { if !missing[sid] {
@ -968,6 +981,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
} }
} }
} }
t.haveEventsMutex.RUnlock()
// fetch as many as we can from the roomserver // fetch as many as we can from the roomserver
queryReq := api.QueryEventsByIDRequest{ queryReq := api.QueryEventsByIDRequest{
@ -977,6 +991,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
return nil, err return nil, err
} }
t.haveEventsMutex.RLock()
for i := range queryRes.Events { for i := range queryRes.Events {
evID := queryRes.Events[i].EventID() evID := queryRes.Events[i].EventID()
t.haveEvents[evID] = queryRes.Events[i] t.haveEvents[evID] = queryRes.Events[i]
@ -984,6 +999,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
delete(missing, evID) delete(missing, evID)
} }
} }
t.haveEventsMutex.RUnlock()
concurrentRequests := 8 concurrentRequests := 8
missingCount := len(missing) missingCount := len(missing)
@ -1034,11 +1050,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
var fetchgroup sync.WaitGroup var fetchgroup sync.WaitGroup
fetchgroup.Add(concurrentRequests) fetchgroup.Add(concurrentRequests)
// This is the only place where we'll write to t.haveEvents from
// multiple goroutines, and everywhere else is blocked on this
// synchronous function anyway.
var haveEventsMutex sync.Mutex
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent var h *gomatrixserverlib.HeaderedEvent
@ -1055,9 +1066,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
}).Info("Failed to fetch missing event") }).Info("Failed to fetch missing event")
return return
} }
haveEventsMutex.Lock() t.haveEventsMutex.Lock()
t.haveEvents[h.EventID()] = h t.haveEvents[h.EventID()] = h
haveEventsMutex.Unlock() t.haveEventsMutex.Unlock()
} }
// Create the worker. // Create the worker.
@ -1084,6 +1095,9 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
// create a RespState response using the response to /state_ids as a guide // create a RespState response using the response to /state_ids as a guide
respState := gomatrixserverlib.RespState{} respState := gomatrixserverlib.RespState{}
t.haveEventsMutex.RLock()
t.haveEventsMutex.RUnlock()
for i := range stateIDs.StateEventIDs { for i := range stateIDs.StateEventIDs {
ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
if !ok { if !ok {
@ -1144,6 +1158,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
h := event.Headered(roomVersion) h := event.Headered(roomVersion)
t.newEventsMutex.Lock()
t.newEvents[h.EventID()] = true t.newEvents[h.EventID()] = true
t.newEventsMutex.Unlock()
return h, nil return h, nil
} }