mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Protect txnReq.newEvents and txnReq.haveEvents with mutex
This commit is contained in:
parent
20a01bceb2
commit
db01b82274
|
|
@ -109,9 +109,11 @@ type txnReq struct {
|
||||||
federation txnFederationClient
|
federation txnFederationClient
|
||||||
// local cache of events for auth checks, etc - this may include events
|
// local cache of events for auth checks, etc - this may include events
|
||||||
// which the roomserver is unaware of.
|
// which the roomserver is unaware of.
|
||||||
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
|
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
|
||||||
|
haveEventsMutex sync.RWMutex
|
||||||
// new events which the roomserver does not know about
|
// new events which the roomserver does not know about
|
||||||
newEvents map[string]bool
|
newEvents map[string]bool
|
||||||
|
newEventsMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// A subset of FederationClient functionality that txn requires. Useful for testing.
|
// A subset of FederationClient functionality that txn requires. Useful for testing.
|
||||||
|
|
@ -264,6 +266,10 @@ func (e missingPrevEventsError) Error() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) haveEventIDs() map[string]bool {
|
func (t *txnReq) haveEventIDs() map[string]bool {
|
||||||
|
t.haveEventsMutex.RLock()
|
||||||
|
t.newEventsMutex.RLock()
|
||||||
|
defer t.haveEventsMutex.RUnlock()
|
||||||
|
defer t.newEventsMutex.RUnlock()
|
||||||
result := make(map[string]bool, len(t.haveEvents))
|
result := make(map[string]bool, len(t.haveEvents))
|
||||||
for eventID := range t.haveEvents {
|
for eventID := range t.haveEvents {
|
||||||
if t.newEvents[eventID] {
|
if t.newEvents[eventID] {
|
||||||
|
|
@ -704,7 +710,9 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
|
||||||
default:
|
default:
|
||||||
return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
|
return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.Lock()
|
||||||
t.haveEvents[h.EventID()] = h
|
t.haveEvents[h.EventID()] = h
|
||||||
|
t.haveEventsMutex.Unlock()
|
||||||
if h.StateKey() != nil {
|
if h.StateKey() != nil {
|
||||||
addedToState := false
|
addedToState := false
|
||||||
for i := range respState.StateEvents {
|
for i := range respState.StateEvents {
|
||||||
|
|
@ -733,6 +741,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
|
||||||
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.RLock()
|
||||||
for i, ev := range res.StateEvents {
|
for i, ev := range res.StateEvents {
|
||||||
t.haveEvents[ev.EventID()] = res.StateEvents[i]
|
t.haveEvents[ev.EventID()] = res.StateEvents[i]
|
||||||
}
|
}
|
||||||
|
|
@ -748,6 +757,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.RUnlock()
|
||||||
// QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't
|
// QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't
|
||||||
// have stored the event.
|
// have stored the event.
|
||||||
var missingEventList []string
|
var missingEventList []string
|
||||||
|
|
@ -762,11 +772,13 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
|
||||||
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.Lock()
|
||||||
for i := range queryRes.Events {
|
for i := range queryRes.Events {
|
||||||
evID := queryRes.Events[i].EventID()
|
evID := queryRes.Events[i].EventID()
|
||||||
t.haveEvents[evID] = queryRes.Events[i]
|
t.haveEvents[evID] = queryRes.Events[i]
|
||||||
authEvents = append(authEvents, queryRes.Events[i].Unwrap())
|
authEvents = append(authEvents, queryRes.Events[i].Unwrap())
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.Unlock()
|
||||||
|
|
||||||
evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents)
|
evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents)
|
||||||
return &gomatrixserverlib.RespState{
|
return &gomatrixserverlib.RespState{
|
||||||
|
|
@ -960,6 +972,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...)
|
wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...)
|
||||||
missing := make(map[string]bool)
|
missing := make(map[string]bool)
|
||||||
var missingEventList []string
|
var missingEventList []string
|
||||||
|
t.haveEventsMutex.RLock()
|
||||||
for _, sid := range wantIDs {
|
for _, sid := range wantIDs {
|
||||||
if _, ok := t.haveEvents[sid]; !ok {
|
if _, ok := t.haveEvents[sid]; !ok {
|
||||||
if !missing[sid] {
|
if !missing[sid] {
|
||||||
|
|
@ -968,6 +981,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.RUnlock()
|
||||||
|
|
||||||
// fetch as many as we can from the roomserver
|
// fetch as many as we can from the roomserver
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
queryReq := api.QueryEventsByIDRequest{
|
||||||
|
|
@ -977,6 +991,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.RLock()
|
||||||
for i := range queryRes.Events {
|
for i := range queryRes.Events {
|
||||||
evID := queryRes.Events[i].EventID()
|
evID := queryRes.Events[i].EventID()
|
||||||
t.haveEvents[evID] = queryRes.Events[i]
|
t.haveEvents[evID] = queryRes.Events[i]
|
||||||
|
|
@ -984,6 +999,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
delete(missing, evID)
|
delete(missing, evID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
t.haveEventsMutex.RUnlock()
|
||||||
|
|
||||||
concurrentRequests := 8
|
concurrentRequests := 8
|
||||||
missingCount := len(missing)
|
missingCount := len(missing)
|
||||||
|
|
@ -1034,11 +1050,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
var fetchgroup sync.WaitGroup
|
var fetchgroup sync.WaitGroup
|
||||||
fetchgroup.Add(concurrentRequests)
|
fetchgroup.Add(concurrentRequests)
|
||||||
|
|
||||||
// This is the only place where we'll write to t.haveEvents from
|
|
||||||
// multiple goroutines, and everywhere else is blocked on this
|
|
||||||
// synchronous function anyway.
|
|
||||||
var haveEventsMutex sync.Mutex
|
|
||||||
|
|
||||||
// Define what we'll do in order to fetch the missing event ID.
|
// Define what we'll do in order to fetch the missing event ID.
|
||||||
fetch := func(missingEventID string) {
|
fetch := func(missingEventID string) {
|
||||||
var h *gomatrixserverlib.HeaderedEvent
|
var h *gomatrixserverlib.HeaderedEvent
|
||||||
|
|
@ -1055,9 +1066,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
|
||||||
}).Info("Failed to fetch missing event")
|
}).Info("Failed to fetch missing event")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
haveEventsMutex.Lock()
|
t.haveEventsMutex.Lock()
|
||||||
t.haveEvents[h.EventID()] = h
|
t.haveEvents[h.EventID()] = h
|
||||||
haveEventsMutex.Unlock()
|
t.haveEventsMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the worker.
|
// Create the worker.
|
||||||
|
|
@ -1084,6 +1095,9 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
|
||||||
// create a RespState response using the response to /state_ids as a guide
|
// create a RespState response using the response to /state_ids as a guide
|
||||||
respState := gomatrixserverlib.RespState{}
|
respState := gomatrixserverlib.RespState{}
|
||||||
|
|
||||||
|
t.haveEventsMutex.RLock()
|
||||||
|
t.haveEventsMutex.RUnlock()
|
||||||
|
|
||||||
for i := range stateIDs.StateEventIDs {
|
for i := range stateIDs.StateEventIDs {
|
||||||
ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
|
ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
@ -1144,6 +1158,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
}
|
}
|
||||||
h := event.Headered(roomVersion)
|
h := event.Headered(roomVersion)
|
||||||
|
t.newEventsMutex.Lock()
|
||||||
t.newEvents[h.EventID()] = true
|
t.newEvents[h.EventID()] = true
|
||||||
|
t.newEventsMutex.Unlock()
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue