diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go b/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go index df9f796f2..b433d707f 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" sarama "gopkg.in/Shopify/sarama.v1" + "sync/atomic" ) // A ConsumerDatabase has the storage APIs needed by the consumer. @@ -47,10 +48,12 @@ type Consumer struct { // If left as nil then the consumer will panic when it encounters an error ErrorLogger ErrorLogger // If non-nil then the consumer will stop processing messages after this - // many messages and will shutdown - StopProcessingAfter *int + // many messages and will shutdown. Malformed messages are included in the count. + StopProcessingAfter *int64 // If not-nil then the consumer will call this to shutdown the server. ShutdownCallback func(reason string) + // How many messages the consumer has processed. + processed int64 } // WriteOutputRoomEvent implements OutputRoomEventWriter @@ -113,14 +116,7 @@ func (c *Consumer) Start() error { // consumePartition consumes the room events for a single partition of the kafkaesque stream. func (c *Consumer) consumePartition(pc sarama.PartitionConsumer) { defer pc.Close() - var processed int for message := range pc.Messages() { - if c.StopProcessingAfter != nil && processed >= *c.StopProcessingAfter { - if c.ShutdownCallback != nil { - c.ShutdownCallback(fmt.Sprintf("Stopping processing after %d messages", processed)) - } - return - } var input api.InputRoomEvent if err := json.Unmarshal(message.Value, &input); err != nil { // If the message is invalid then log it and move onto the next message in the stream. @@ -139,7 +135,19 @@ func (c *Consumer) consumePartition(pc sarama.PartitionConsumer) { if err := c.DB.SetPartitionOffset(c.InputRoomEventTopic, message.Partition, message.Offset); err != nil { c.logError(message, err) } - processed++ + // Update the number of processed messages using atomic addition because it is accessed from multiple goroutines. + processed := atomic.AddInt64(&c.processed, 1) + // Check if we should stop processing. + // Note that since we have multiple goroutines it's quite likely that we'll overshoot by a few messages. + // If we try to stop processing after M message and we have N goroutines then we will process somewhere + // between M and (N + M) messages because the N goroutines could all try to process what they think will be the + // last message. We could be more careful here but this is good enough for getting rough benchmarks. + if c.StopProcessingAfter != nil && processed >= int64(*c.StopProcessingAfter) { + if c.ShutdownCallback != nil { + c.ShutdownCallback(fmt.Sprintf("Stopping processing after %d messages", c.processed)) + } + return + } } } diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go index bfcaea146..e865b5bea 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go @@ -19,7 +19,10 @@ var ( inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT") outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT") bindAddr = os.Getenv("BIND_ADDRESS") - stopProcessingAfter = os.Getenv("STOP_AFTER") + // Shuts the roomserver down after processing a given number of messages. + // This is useful for running benchmarks for seeing how quickly the server + // can process a given number of messages. + stopProcessingAfter = os.Getenv("STOP_AFTER") ) func main() { @@ -47,7 +50,7 @@ func main() { } if stopProcessingAfter != "" { - count, err := strconv.Atoi(stopProcessingAfter) + count, err := strconv.ParseInt(stopProcessingAfter, 10, 64) if err != nil { panic(err) }