diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 3bbeb439a..13639a5e3 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -2,6 +2,8 @@ package consumers import ( "context" + "reflect" + "sync" "testing" "github.com/matrix-org/gomatrixserverlib" @@ -11,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi/storage" + userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -127,3 +130,87 @@ func Test_evaluatePushRules(t *testing.T) { } }) } + +func TestMessageStats(t *testing.T) { + type args struct { + eventType string + eventSender string + } + tests := []struct { + name string + args args + ourServer gomatrixserverlib.ServerName + wantStats userAPITypes.MessageStats + }{ + { + name: "m.room.create does not count as a message", + ourServer: "localhost", + args: args{ + eventType: "m.room.create", + eventSender: "@alice:localhost", + }, + }, + { + name: "our server - message", + ourServer: "localhost", + args: args{ + eventType: "m.room.message", + eventSender: "@alice:localhost", + }, + wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1}, + }, + { + name: "our server - E2EE message", + ourServer: "localhost", + args: args{ + eventType: "m.room.encrypted", + eventSender: "@alice:localhost", + }, + wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1}, + }, + + { + name: "remote server - message", + ourServer: "localhost", + args: args{ + eventType: "m.room.message", + eventSender: "@alice:remote", + }, + wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1}, + }, + { + name: "remote server - E2EE message", + ourServer: "localhost", + args: args{ + eventType: "m.room.encrypted", + eventSender: "@alice:remote", + }, + wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 2, SentMessagesE2EE: 1}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &OutputRoomEventConsumer{ + db: db, + msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, + msgCountsLock: sync.Mutex{}, + serverName: tt.ourServer, + } + s.storeMessageStats(context.Background(), tt.args.eventType, tt.args.eventSender) + + gotStats, err := db.DailyMessages(context.Background(), tt.ourServer) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !reflect.DeepEqual(gotStats, tt.wantStats) { + t.Fatalf("expected %+v, got %+v", tt.wantStats, gotStats) + } + }) + } + }) +}