diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 87bdc5dbf..fc73d492a 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,12 +19,15 @@ import ( "context" "encoding/json" "sync" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "go.uber.org/atomic" ) type Inputer struct { @@ -33,7 +36,43 @@ type Inputer struct { ServerName gomatrixserverlib.ServerName OutputRoomEventTopic string - mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent + workers sync.Map // room ID -> *inputWorker +} + +type inputTask struct { + event api.InputRoomEvent + wg *sync.WaitGroup + eventID string // written back by worker + err error // written back by worker +} + +type inputWorker struct { + r *Inputer + running atomic.Bool + input chan *inputTask +} + +func (w *inputWorker) start() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + + logrus.Warn("STARTING WORKER") + defer logrus.Warn("SHUTTING DOWN WORKER") + + for { + select { + case task := <-w.input: + logrus.Warn("WORKER DOING TASK") + task.eventID, task.err = w.r.processRoomEvent(context.TODO(), task.event) + logrus.Warn("WORKER FINISHING TASK") + task.wg.Done() + logrus.Warn("WORKER FINISHED TASK") + case <-time.After(time.Second * 5): + return + } + } } // WriteOutputEvents implements OutputRoomEventWriter @@ -74,18 +113,50 @@ func (r *Inputer) InputRoomEvents( request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, ) (err error) { + wg := &sync.WaitGroup{} + wg.Add(len(request.InputRoomEvents)) + tasks := make([]*inputTask, len(request.InputRoomEvents)) + logrus.Warnf("Received %d input events", len(tasks)) + for i, e := range request.InputRoomEvents { + // Work out if we are running per-room workers or if we're just doing + // it on a global basis (e.g. SQLite). roomID := "global" if r.DB.SupportsConcurrentRoomInputs() { roomID = e.Event.RoomID() } - mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{}) - mutex.(*sync.Mutex).Lock() - if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { - mutex.(*sync.Mutex).Unlock() - return err + + // Look up the worker, or create it if it doesn't exist. + w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ + r: r, + input: make(chan *inputTask), + }) + worker := w.(*inputWorker) + + // Create a task. This contains the input event and a reference to + // the wait group, so that the worker can notify us when this specific + // task has been finished. + tasks[i] = &inputTask{ + event: e, + wg: wg, } - mutex.(*sync.Mutex).Unlock() + + // Send the task to the worker. + go func(task *inputTask) { worker.input <- task }(tasks[i]) + go worker.start() } + + logrus.Warnf("Waiting for %d task(s)", len(tasks)) + wg.Wait() + logrus.Warnf("Tasks finished") + + for _, task := range tasks { + if task.err != nil { + logrus.Warnf("Error: %w", task.err) + } else { + logrus.Warnf("Event ID: %s", task.eventID) + } + } + return nil }