diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index ab16a5293..5d15c732f 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -95,30 +95,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } - _, sender, _ := gomatrixserverlib.SplitID('@', event.Sender()) - - switch event.Type() { - case "m.room.message": - s.msgCountsLock.Lock() - msgCount := s.msgCounts[s.serverName] - msgCount.Messages++ - if sender == s.serverName { - msgCount.SentMessages++ - } - s.msgCounts[s.serverName] = msgCount - s.msgCountsLock.Unlock() - case "m.room.encrypted": - s.msgCountsLock.Lock() - msgCount := s.msgCounts[s.serverName] - msgCount.MessagesE2EE++ - if sender == s.serverName { - msgCount.SentMessagesE2EE++ - } - s.msgCounts[s.serverName] = msgCount - s.msgCountsLock.Unlock() - } - - s.storeMessageStats(ctx) + s.storeMessageStats(ctx, event.Type(), event.Sender()) log.WithFields(log.Fields{ "event_id": event.EventID(), @@ -139,17 +116,30 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } -func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context) { +func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventType, eventSender string) { s.msgCountsLock.Lock() defer s.msgCountsLock.Unlock() - var sumStats int64 = 0 - for _, stats := range s.msgCounts { - sumStats += stats.Messages + stats.SentMessages + stats.MessagesE2EE + stats.SentMessagesE2EE - } - // Nothing to do - if sumStats == 0 { + + _, sender, err := gomatrixserverlib.SplitID('@', eventSender) + if err != nil { return } + msgCount := s.msgCounts[s.serverName] + switch eventType { + case "m.room.message": + msgCount.Messages++ + if sender == s.serverName { + msgCount.SentMessages++ + } + case "m.room.encrypted": + msgCount.MessagesE2EE++ + if sender == s.serverName { + msgCount.SentMessagesE2EE++ + } + default: + return + } + s.msgCounts[s.serverName] = msgCount for serverName, stats := range s.msgCounts { err := s.db.UpsertDailyMessages(ctx, serverName, stats) if err != nil {