diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go index 498be3c43..5ff085282 100644 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ b/cmd/dendrite-polylith-multi/personalities/federationapi.go @@ -33,7 +33,7 @@ func FederationAPI(base *setup.BaseDendrite, cfg *config.Dendrite) { base.PublicFederationAPIMux, base.PublicKeyAPIMux, &base.Cfg.FederationAPI, userAPI, federation, keyRing, rsAPI, fsAPI, base.EDUServerClient(), keyAPI, - &base.Cfg.MSCs, + &base.Cfg.MSCs, nil, ) base.SetupAndServeHTTP( diff --git a/federationapi/api/servers.go b/federationapi/api/servers.go new file mode 100644 index 000000000..6bb15763d --- /dev/null +++ b/federationapi/api/servers.go @@ -0,0 +1,11 @@ +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +type ServersInRoomProvider interface { + GetServersForRoom(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName +} diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 6188b283e..b3297434a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -17,6 +17,7 @@ package federationapi import ( "github.com/gorilla/mux" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -39,10 +40,12 @@ func AddPublicRoutes( eduAPI eduserverAPI.EDUServerInputAPI, keyAPI keyserverAPI.KeyInternalAPI, mscCfg *config.MSCs, + servers federationAPI.ServersInRoomProvider, ) { routing.Setup( fedRouter, keyRouter, cfg, rsAPI, eduAPI, federationSenderAPI, keyRing, federation, userAPI, keyAPI, mscCfg, + servers, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index b97876d3d..505a11dae 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { fsAPI := base.FederationSenderHTTPClient() // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs) + federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs, nil) baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 07a28c3fc..8f33c7660 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -20,6 +20,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" @@ -50,6 +51,7 @@ func Setup( userAPI userapi.UserInternalAPI, keyAPI keyserverAPI.KeyInternalAPI, mscCfg *config.MSCs, + servers federationAPI.ServersInRoomProvider, ) { v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() @@ -99,7 +101,7 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, eduAPI, keyAPI, keys, federation, mu, + cfg, rsAPI, eduAPI, keyAPI, keys, federation, mu, servers, ) }, )).Methods(http.MethodPut, http.MethodOptions) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 708ba38ec..ae9a63fc2 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -16,16 +16,16 @@ package routing import ( "context" - "database/sql" "encoding/json" + "errors" "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" @@ -34,6 +34,7 @@ import ( "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "go.uber.org/atomic" ) const ( @@ -88,6 +89,67 @@ func init() { ) } +type sendFIFOQueue struct { + tasks []*inputTask + count int + mutex sync.Mutex + notifs chan struct{} +} + +func newSendFIFOQueue() *sendFIFOQueue { + q := &sendFIFOQueue{ + notifs: make(chan struct{}, 1), + } + return q +} + +func (q *sendFIFOQueue) push(frame *inputTask) { + q.mutex.Lock() + defer q.mutex.Unlock() + q.tasks = append(q.tasks, frame) + q.count++ + select { + case q.notifs <- struct{}{}: + default: + } +} + +// pop returns the first item of the queue, if there is one. +// The second return value will indicate if a task was returned. +func (q *sendFIFOQueue) pop() (*inputTask, bool) { + q.mutex.Lock() + defer q.mutex.Unlock() + if q.count == 0 { + return nil, false + } + frame := q.tasks[0] + q.tasks[0] = nil + q.tasks = q.tasks[1:] + q.count-- + if q.count == 0 { + // Force a GC of the underlying array, since it might have + // grown significantly if the queue was hammered for some reason + q.tasks = nil + } + return frame, true +} + +type inputTask struct { + ctx context.Context + t *txnReq + event *gomatrixserverlib.Event + wg *sync.WaitGroup + err error // written back by worker, only safe to read when all tasks are done + duration time.Duration // written back by worker, only safe to read when all tasks are done +} + +type inputWorker struct { + running atomic.Bool + input *sendFIFOQueue +} + +var inputWorkers sync.Map // room ID -> *inputWorker + // Send implements /_matrix/federation/v1/send/{txnID} func Send( httpReq *http.Request, @@ -100,14 +162,16 @@ func Send( keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, mu *internal.MutexByRoom, + servers federationAPI.ServersInRoomProvider, ) util.JSONResponse { t := txnReq{ rsAPI: rsAPI, eduAPI: eduAPI, keys: keys, federation: federation, + hadEvents: make(map[string]bool), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - newEvents: make(map[string]bool), + servers: servers, keyAPI: keyAPI, roomsMu: mu, } @@ -159,21 +223,21 @@ func Send( type txnReq struct { gomatrixserverlib.Transaction - rsAPI api.RoomserverInternalAPI - eduAPI eduserverAPI.EDUServerInputAPI - keyAPI keyapi.KeyInternalAPI - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - servers []gomatrixserverlib.ServerName - serversMutex sync.RWMutex - roomsMu *internal.MutexByRoom + rsAPI api.RoomserverInternalAPI + eduAPI eduserverAPI.EDUServerInputAPI + keyAPI keyapi.KeyInternalAPI + keys gomatrixserverlib.JSONVerifier + federation txnFederationClient + roomsMu *internal.MutexByRoom + // something that can tell us about which servers are in a room right now + servers federationAPI.ServersInRoomProvider + // a list of events from the auth and prev events which we already had + hadEvents map[string]bool // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. - haveEvents map[string]*gomatrixserverlib.HeaderedEvent - // new events which the roomserver does not know about - newEvents map[string]bool - newEventsMutex sync.RWMutex - work string // metrics + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEventsMutex sync.Mutex + work string // metrics } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -189,8 +253,12 @@ type txnFederationClient interface { func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) + //var resultsMutex sync.Mutex + + var wg sync.WaitGroup + var tasks []*inputTask + wg.Add(1) // for processEDUs - pdus := []*gomatrixserverlib.HeaderedEvent{} for _, pdu := range t.PDUs { pduCountTotal.WithLabelValues("total").Inc() var header struct { @@ -241,83 +309,97 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } continue } - pdus = append(pdus, event.Headered(verRes.RoomVersion)) + v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{ + input: newSendFIFOQueue(), + }) + worker := v.(*inputWorker) + if !worker.running.Load() { + go worker.run() + } + wg.Add(1) + task := &inputTask{ + ctx: ctx, + t: t, + event: event, + wg: &wg, + } + tasks = append(tasks, task) + worker.input.push(task) } - // Process the events. - for _, e := range pdus { - evStart := time.Now() - if err := t.processEvent(ctx, e.Unwrap()); err != nil { - // If the error is due to the event itself being bad then we skip - // it and move onto the next event. We report an error so that the - // sender knows that we have skipped processing it. - // - // However if the event is due to a temporary failure in our server - // such as a database being unavailable then we should bail, and - // hope that the sender will retry when we are feeling better. - // - // It is uncertain what we should do if an event fails because - // we failed to fetch more information from the sending server. - // For example if a request to /state fails. - // If we skip the event then we risk missing the event until we - // receive another event referencing it. - // If we bail and stop processing then we risk wedging incoming - // transactions from that server forever. - if isProcessingErrorFatal(err) { - sentry.CaptureException(err) - // Any other error should be the result of a temporary error in - // our server so we should bail processing the transaction entirely. - util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err) - jsonErr := util.ErrorResponse(err) - processEventSummary.WithLabelValues(t.work, MetricsOutcomeFatal).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - return nil, &jsonErr - } else { - // Auth errors mean the event is 'rejected' which have to be silent to appease sytest - errMsg := "" - outcome := MetricsOutcomeRejected - _, rejected := err.(*gomatrixserverlib.NotAllowed) - if !rejected { - errMsg = err.Error() - outcome = MetricsOutcomeFail - } - util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( - "Failed to process incoming federation event, skipping", - ) - processEventSummary.WithLabelValues(t.work, outcome).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - results[e.EventID()] = gomatrixserverlib.PDUResult{ - Error: errMsg, - } + go func() { + defer wg.Done() + t.processEDUs(ctx) + }() + + wg.Wait() + + for _, task := range tasks { + if task.err != nil { + results[task.event.EventID()] = gomatrixserverlib.PDUResult{ + Error: task.err.Error(), } } else { - results[e.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - processEventSummary.WithLabelValues(t.work, MetricsOutcomeOK).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) + results[task.event.EventID()] = gomatrixserverlib.PDUResult{} } } - t.processEDUs(ctx) if c := len(results); c > 0 { util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) } return &gomatrixserverlib.RespSend{PDUs: results}, nil } -// isProcessingErrorFatal returns true if the error is really bad and -// we should stop processing the transaction, and returns false if it -// is just some less serious error about a specific event. -func isProcessingErrorFatal(err error) bool { - switch err { - case sql.ErrConnDone: - case sql.ErrTxDone: - return true +func (t *inputWorker) run() { + if !t.running.CAS(false, true) { + return + } + defer t.running.Store(false) + for { + task, ok := t.input.pop() + if !ok { + return + } + if task == nil { + continue + } + func() { + defer task.wg.Done() + select { + case <-task.ctx.Done(): + task.err = context.DeadlineExceeded + return + default: + evStart := time.Now() + task.err = task.t.processEvent(task.ctx, task.event) + task.duration = time.Since(evStart) + if err := task.err; err != nil { + switch err.(type) { + case *gomatrixserverlib.NotAllowed: + processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe( + float64(time.Since(evStart).Nanoseconds()) / 1000., + ) + util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn( + "Failed to process incoming federation event, skipping", + ) + task.err = nil // make "rejected" failures silent + default: + processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe( + float64(time.Since(evStart).Nanoseconds()) / 1000., + ) + util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn( + "Failed to process incoming federation event, skipping", + ) + } + } else { + pduCountTotal.WithLabelValues("success").Inc() + processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe( + float64(time.Since(evStart).Nanoseconds()) / 1000., + ) + } + } + }() } - return false } type roomNotFoundError struct { @@ -340,19 +422,6 @@ func (e missingPrevEventsError) Error() string { return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) } -func (t *txnReq) haveEventIDs() map[string]bool { - t.newEventsMutex.RLock() - defer t.newEventsMutex.RUnlock() - result := make(map[string]bool, len(t.haveEvents)) - for eventID := range t.haveEvents { - if t.newEvents[eventID] { - continue - } - result[eventID] = true - } - return result -} - func (t *txnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { eduCountTotal.Inc() @@ -479,22 +548,24 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli } } -func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName { - t.serversMutex.Lock() - defer t.serversMutex.Unlock() +func (t *txnReq) getServers(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName { + // The server that sent us the event should be sufficient to tell us about missing + // prev and auth events. + servers := []gomatrixserverlib.ServerName{t.Origin} + // If the event origin is different to the transaction origin then we can use + // this as a last resort. The origin server that created the event would have + // had to know the auth and prev events. + if event != nil { + if origin := event.Origin(); origin != t.Origin { + servers = append(servers, origin) + } + } + // If a specific room-to-server provider exists then use that. This will primarily + // be used for the P2P demos. if t.servers != nil { - return t.servers + servers = append(servers, t.servers.GetServersForRoom(ctx, roomID, event)...) } - t.servers = []gomatrixserverlib.ServerName{t.Origin} - serverReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: roomID, - } - serverRes := &api.QueryServerJoinedToRoomResponse{} - if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { - t.servers = append(t.servers, serverRes.ServerNames...) - util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(t.servers), roomID) - } - return t.servers + return servers } func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { @@ -527,6 +598,15 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e return roomNotFoundError{e.RoomID()} } + // Prepare a map of all the events we already had before this point, so + // that we don't send them to the roomserver again. + for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { + t.hadEvents[eventID] = true + } + for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { + t.hadEvents[eventID] = false + } + if len(stateResp.MissingAuthEventIDs) > 0 { t.work = MetricsWorkMissingAuthEvents logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) @@ -570,11 +650,14 @@ func (t *txnReq) retrieveMissingAuthEvents( withNextEvent: for missingAuthEventID := range missingAuthEvents { withNextServer: - for _, server := range t.getServers(ctx, e.RoomID()) { + for _, server := range t.getServers(ctx, e.RoomID(), e) { logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) if err != nil { logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID) + if errors.Is(err, context.DeadlineExceeded) { + return err + } continue withNextServer } ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) @@ -596,6 +679,8 @@ withNextEvent: ); err != nil { return fmt.Errorf("api.SendEvents: %w", err) } + t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now + t.cacheAndReturn(ev.Headered(stateResp.RoomVersion)) delete(missingAuthEvents, missingAuthEventID) continue withNextEvent } @@ -618,14 +703,14 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv return gomatrixserverlib.Allowed(e, &authUsingState) } +var processEventWithMissingStateMutexes = internal.NewMutexByRoom() + func (t *txnReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) error { - // Do this with a fresh context, so that we keep working even if the - // original request times out. With any luck, by the time the remote - // side retries, we'll have fetched the missing state. - gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() + processEventWithMissingStateMutexes.Lock(e.RoomID()) + defer processEventWithMissingStateMutexes.Unlock(e.RoomID()) + // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -646,7 +731,7 @@ func (t *txnReq) processEventWithMissingState( // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, err := t.getMissingEvents(gmectx, e, roomVersion) + newEvents, err := t.getMissingEvents(ctx, e, roomVersion) if err != nil { return err } @@ -673,7 +758,7 @@ func (t *txnReq) processEventWithMissingState( // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) + prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) if lerr != nil { util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return lerr @@ -717,7 +802,7 @@ func (t *txnReq) processEventWithMissingState( } // There's more than one previous state - run them all through state res t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) t.roomsMu.Unlock(e.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) @@ -734,7 +819,7 @@ func (t *txnReq) processEventWithMissingState( api.KindOld, resolvedState, backwardsExtremity.Headered(roomVersion), - t.haveEventIDs(), + t.hadEvents, ) if err != nil { return fmt.Errorf("api.SendEventWithState: %w", err) @@ -786,7 +871,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix default: return nil, false, fmt.Errorf("t.lookupEvent: %w", err) } - t.cacheAndReturn(h) + h = t.cacheAndReturn(h) if h.StateKey() != nil { addedToState := false for i := range respState.StateEvents { @@ -806,6 +891,8 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix } func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() if cached, exists := t.haveEvents[ev.EventID()]; exists { return cached } @@ -828,6 +915,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. stateEvents[i] = t.cacheAndReturn(ev) + t.hadEvents[ev.EventID()] = true } // we should never access res.StateEvents again so we delete it here to make GC faster res.StateEvents = nil @@ -835,6 +923,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event var authEvents []*gomatrixserverlib.Event missingAuthEvents := map[string]bool{} for _, ev := range stateEvents { + t.haveEventsMutex.Lock() for _, ae := range ev.AuthEventIDs() { if aev, ok := t.haveEvents[ae]; ok { authEvents = append(authEvents, aev.Unwrap()) @@ -842,6 +931,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event missingAuthEvents[ae] = true } } + t.haveEventsMutex.Unlock() } // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't // have stored the event. @@ -858,8 +948,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } - for i := range queryRes.Events { + for i, ev := range queryRes.Events { authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + t.hadEvents[ev.EventID()] = true } queryRes.Events = nil } @@ -934,12 +1025,13 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even return nil, err } latestEvents := make([]string, len(res.LatestEvents)) - for i := range res.LatestEvents { + for i, ev := range res.LatestEvents { latestEvents[i] = res.LatestEvents[i].EventID + t.hadEvents[ev.EventID] = true } var missingResp *gomatrixserverlib.RespMissingEvents - servers := t.getServers(ctx, e.RoomID()) + servers := t.getServers(ctx, e.RoomID(), e) for _, server := range servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ @@ -953,6 +1045,9 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even break } else { logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.Origin, server) + if errors.Is(err, context.DeadlineExceeded) { + break + } } } @@ -980,6 +1075,12 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even // For now, we do not allow Case B, so reject the event. logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + // Make sure events from the missingResp are using the cache - missing events + // will be added and duplicates will be removed. + for i, ev := range missingResp.Events { + missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + // topologically sort and sanity check that we are making forward progress newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() @@ -1018,6 +1119,14 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID if err := state.Check(ctx, t.keys, nil); err != nil { return nil, err } + // Cache the results of this state lookup and deduplicate anything we already + // have in the cache, freeing up memory. + for i, ev := range state.AuthEvents { + state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + for i, ev := range state.StateEvents { + state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } return &state, nil } @@ -1033,6 +1142,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) missing := make(map[string]bool) var missingEventList []string + t.haveEventsMutex.Lock() for _, sid := range wantIDs { if _, ok := t.haveEvents[sid]; !ok { if !missing[sid] { @@ -1041,6 +1151,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } } } + t.haveEventsMutex.Unlock() // fetch as many as we can from the roomserver queryReq := api.QueryEventsByIDRequest{ @@ -1050,9 +1161,10 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } - for i := range queryRes.Events { + for i, ev := range queryRes.Events { + queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + t.hadEvents[ev.EventID()] = true evID := queryRes.Events[i].EventID() - t.cacheAndReturn(queryRes.Events[i]) if missing[evID] { delete(missing, evID) } @@ -1153,6 +1265,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( *gomatrixserverlib.RespState, error) { // nolint:unparam + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() + // create a RespState response using the response to /state_ids as a guide respState := gomatrixserverlib.RespState{} @@ -1193,11 +1308,14 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. } var event *gomatrixserverlib.Event found := false - servers := t.getServers(ctx, roomID) + servers := t.getServers(ctx, roomID, nil) for _, serverName := range servers { txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") + if errors.Is(err, context.DeadlineExceeded) { + break + } continue } event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) @@ -1216,9 +1334,5 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) return nil, verifySigError{event.EventID(), err} } - h := event.Headered(roomVersion) - t.newEventsMutex.Lock() - t.newEvents[h.EventID()] = true - t.newEventsMutex.Unlock() - return h, nil + return t.cacheAndReturn(event.Headered(roomVersion)), nil } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b14cbd35a..98ff1a0a3 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -370,7 +370,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat keys: &test.NopJSONVerifier{}, federation: fedClient, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - newEvents: make(map[string]bool), + hadEvents: make(map[string]bool), roomsMu: internal.NewMutexByRoom(), } t.PDUs = pdus diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 82ece2307..6bc43c9cd 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" + "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "go.uber.org/atomic" ) @@ -38,8 +39,7 @@ type Inputer struct { ServerName gomatrixserverlib.ServerName ACLs *acls.ServerACLs OutputRoomEventTopic string - - workers sync.Map // room ID -> *inputWorker + workers sync.Map // room ID -> *inputWorker } type inputTask struct { @@ -52,7 +52,7 @@ type inputTask struct { type inputWorker struct { r *Inputer running atomic.Bool - input chan *inputTask + input *fifoQueue } // Guarded by a CAS on w.running @@ -60,7 +60,14 @@ func (w *inputWorker) start() { defer w.running.Store(false) for { select { - case task := <-w.input: + case <-w.input.wait(): + task, ok := w.input.pop() + if !ok { + continue + } + roomserverInputBackpressure.With(prometheus.Labels{ + "room_id": task.event.Event.RoomID(), + }).Dec() hooks.Run(hooks.KindNewEventReceived, task.event.Event) _, task.err = w.r.processRoomEvent(task.ctx, task.event) if task.err == nil { @@ -117,6 +124,20 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er return errs } +func init() { + prometheus.MustRegister(roomserverInputBackpressure) +} + +var roomserverInputBackpressure = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "input_backpressure", + Help: "How many events are queued for input for a given room", + }, + []string{"room_id"}, +) + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( _ context.Context, @@ -143,7 +164,7 @@ func (r *Inputer) InputRoomEvents( // room - the channel will be quite small as it's just pointer types. w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ r: r, - input: make(chan *inputTask, 32), + input: newFIFOQueue(), }) worker := w.(*inputWorker) @@ -160,7 +181,10 @@ func (r *Inputer) InputRoomEvents( if worker.running.CAS(false, true) { go worker.start() } - worker.input <- tasks[i] + worker.input.push(tasks[i]) + roomserverInputBackpressure.With(prometheus.Labels{ + "room_id": roomID, + }).Inc() } // Wait for all of the workers to return results about our tasks. diff --git a/roomserver/internal/input/input_fifo.go b/roomserver/internal/input/input_fifo.go new file mode 100644 index 000000000..694b17245 --- /dev/null +++ b/roomserver/internal/input/input_fifo.go @@ -0,0 +1,64 @@ +package input + +import ( + "sync" +) + +type fifoQueue struct { + tasks []*inputTask + count int + mutex sync.Mutex + notifs chan struct{} +} + +func newFIFOQueue() *fifoQueue { + q := &fifoQueue{ + notifs: make(chan struct{}, 1), + } + return q +} + +func (q *fifoQueue) push(frame *inputTask) { + q.mutex.Lock() + defer q.mutex.Unlock() + q.tasks = append(q.tasks, frame) + q.count++ + select { + case q.notifs <- struct{}{}: + default: + } +} + +// pop returns the first item of the queue, if there is one. +// The second return value will indicate if a task was returned. +// You must check this value, even after calling wait(). +func (q *fifoQueue) pop() (*inputTask, bool) { + q.mutex.Lock() + defer q.mutex.Unlock() + if q.count == 0 { + return nil, false + } + frame := q.tasks[0] + q.tasks[0] = nil + q.tasks = q.tasks[1:] + q.count-- + if q.count == 0 { + // Force a GC of the underlying array, since it might have + // grown significantly if the queue was hammered for some reason + q.tasks = nil + } + return frame, true +} + +// wait returns a channel which can be used to detect when an +// item is waiting in the queue. +func (q *fifoQueue) wait() <-chan struct{} { + q.mutex.Lock() + defer q.mutex.Unlock() + if q.count > 0 && len(q.notifs) == 0 { + ch := make(chan struct{}) + close(ch) + return ch + } + return q.notifs +} diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go index 84da96149..d87ae052b 100644 --- a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go @@ -119,11 +119,15 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { _roomserver_state_snapshots JOIN _roomserver_state_block ON _roomserver_state_block.state_block_nid = ANY (_roomserver_state_snapshots.state_block_nids) WHERE - _roomserver_state_snapshots.state_snapshot_nid = ANY ( SELECT DISTINCT + _roomserver_state_snapshots.state_snapshot_nid = ANY ( + SELECT _roomserver_state_snapshots.state_snapshot_nid FROM _roomserver_state_snapshots - LIMIT $1 OFFSET $2)) AS _roomserver_state_block + ORDER BY _roomserver_state_snapshots.state_snapshot_nid ASC + LIMIT $1 OFFSET $2 + ) + ) AS _roomserver_state_block GROUP BY state_snapshot_nid, room_nid, @@ -202,6 +206,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { } } + // By this point we should have no more state_snapshot_nids below maxsnapshotid in either roomserver_rooms or roomserver_events + // If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist + // in roomserver_state_snapshots + var count int64 + if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { + return fmt.Errorf("assertion query failed: %s", err) + } + if count > 0 { + return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + } + if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { + return fmt.Errorf("assertion query failed: %s", err) + } + if count > 0 { + return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + } + if _, err = tx.Exec(` DROP TABLE _roomserver_state_snapshots; DROP SEQUENCE roomserver_state_snapshot_nid_seq; diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go index 3b93b3fa6..42edbbc6f 100644 --- a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go @@ -31,6 +31,7 @@ func LoadStateBlocksRefactor(m *sqlutil.Migrations) { m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) } +// nolint:gocyclo func UpStateBlocksRefactor(tx *sql.Tx) error { logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") defer logrus.Warn("State storage upgrade complete") @@ -45,6 +46,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { } maxsnapshotid++ maxblockid++ + oldMaxSnapshotID := maxsnapshotid if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { return fmt.Errorf("tx.Exec: %w", err) @@ -133,6 +135,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { if jerr != nil { return fmt.Errorf("json.Marshal (new blocks): %w", jerr) } + var newsnapshot types.StateSnapshotNID err = tx.QueryRow(` INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids) @@ -144,7 +147,8 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) } maxsnapshotid++ - if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil { + _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid) + if err != nil { return fmt.Errorf("tx.Exec (update events): %w", err) } if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil { @@ -153,6 +157,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { } } + // By this point we should have no more state_snapshot_nids below oldMaxSnapshotID in either roomserver_rooms or roomserver_events + // If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist + // in roomserver_state_snapshots + var count int64 + if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { + return fmt.Errorf("assertion query failed: %s", err) + } + if count > 0 { + return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + } + if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { + return fmt.Errorf("assertion query failed: %s", err) + } + if count > 0 { + return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + } + if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) } diff --git a/setup/monolith.go b/setup/monolith.go index a740ebb7f..235be4474 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -68,7 +68,7 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss federationapi.AddPublicRoutes( ssMux, keyMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, - m.EDUInternalAPI, m.KeyAPI, &m.Config.MSCs, + m.EDUInternalAPI, m.KeyAPI, &m.Config.MSCs, nil, ) mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, m.UserAPI, m.Client) syncapi.AddPublicRoutes(