dendrite/userapi/consumers/roomserver_test.go

253 lines
7.2 KiB
Go
Raw Permalink Normal View History

package consumers
import (
"context"
"reflect"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage"
userAPITypes "github.com/matrix-org/dendrite/userapi/types"
)
// mustCreateDatabase creates a storage.UserDatabase backed by the requested
// database type, failing the test immediately if construction errors.
//
// It returns the database together with a cleanup func that invokes the
// teardown function obtained from test.PrepareDBConnectionString; callers are
// expected to defer it.
//
// NOTE(review): the meaning of the trailing positional arguments passed to
// storage.NewUserDatabase ("", 4, 0, 0, "") is not visible from this file.
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
	t.Helper()

	connStr, closeDB := test.PrepareDBConnectionString(t, dbType)
	dbOpts := &config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}
	connMgr := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})

	db, err := storage.NewUserDatabase(context.Background(), connMgr, dbOpts, "", 4, 0, 0, "")
	if err != nil {
		t.Fatalf("failed to create new user db: %v", err)
	}

	cleanup := func() {
		closeDB()
	}
	return db, cleanup
}
// mustCreateEvent parses content as trusted event JSON using room version 10
// and wraps the resulting PDU in a *types.HeaderedEvent. The test is failed
// immediately if the JSON cannot be turned into an event.
func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
	t.Helper()

	roomVersion := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10)
	pdu, err := roomVersion.NewEventFromTrustedJSON([]byte(content), false)
	if err != nil {
		t.Fatalf("failed to create event: %v", err)
	}

	return &types.HeaderedEvent{PDU: pdu}
}
// Test_evaluatePushRules evaluates push rules for a local user
// (@test:localhost) against a few event types on every supported database
// backend. For each case it asserts:
//   - the exact action list returned by evaluatePushRules,
//   - the primary action reported by pushrules.ActionsToTweaks,
//   - whether that action would trigger a notification, in BOTH directions,
//     so the "doesn't notify" cases are actually verified.
func Test_evaluatePushRules(t *testing.T) {
	ctx := context.Background()

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		db, closeDB := mustCreateDatabase(t, dbType)
		defer closeDB()
		consumer := OutputRoomEventConsumer{db: db}

		testCases := []struct {
			name         string
			eventContent string
			wantAction   pushrules.ActionKind
			wantActions  []*pushrules.Action
			wantNotify   bool
		}{
			{
				name:         "m.receipt doesn't notify",
				eventContent: `{"type":"m.receipt"}`,
				wantAction:   pushrules.UnknownAction,
				wantActions:  nil,
			},
			{
				name:         "m.reaction doesn't notify",
				eventContent: `{"type":"m.reaction"}`,
				wantAction:   pushrules.DontNotifyAction,
				wantActions: []*pushrules.Action{
					{
						Kind: pushrules.DontNotifyAction,
					},
				},
			},
			{
				name:         "m.room.message notifies",
				eventContent: `{"type":"m.room.message"}`,
				wantNotify:   true,
				wantAction:   pushrules.NotifyAction,
				wantActions: []*pushrules.Action{
					{Kind: pushrules.NotifyAction},
				},
			},
			{
				// The body "test" matches the user's localpart below.
				name:         "m.room.message highlights",
				eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`,
				wantNotify:   true,
				wantAction:   pushrules.NotifyAction,
				wantActions: []*pushrules.Action{
					{Kind: pushrules.NotifyAction},
					{
						Kind:  pushrules.SetTweakAction,
						Tweak: pushrules.SoundTweak,
						Value: "default",
					},
					{
						Kind:  pushrules.SetTweakAction,
						Tweak: pushrules.HighlightTweak,
					},
				},
			},
		}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				actions, err := consumer.evaluatePushRules(ctx, mustCreateEvent(t, tc.eventContent), &localMembership{
					UserID:    "@test:localhost",
					Localpart: "test",
					Domain:    "localhost",
				}, 10)
				if err != nil {
					t.Fatalf("failed to evaluate push rules: %v", err)
				}
				assert.Equal(t, tc.wantActions, actions)

				gotAction, _, err := pushrules.ActionsToTweaks(actions)
				if err != nil {
					t.Fatalf("failed to get actions: %v", err)
				}
				if gotAction != tc.wantAction {
					t.Fatalf("expected action to be '%s', got '%s'", tc.wantAction, gotAction)
				}

				// This notify condition is taken from `notifyLocal`. Compare in both
				// directions so that cases expecting no notification fail if the
				// event would in fact notify.
				gotNotify := gotAction == pushrules.NotifyAction || gotAction == pushrules.CoalesceAction
				if gotNotify != tc.wantNotify {
					t.Fatalf("expected notify=%t, got notify=%t (action '%s')", tc.wantNotify, gotNotify, gotAction)
				}
			})
		}
	})
}
// TestMessageStats feeds one event per case into storeMessageStats and checks
// the daily statistics read back via DailyRoomsMessages.
//
// The cases are order-dependent. All cases share one database per backend, so
// the persisted message counters accumulate, and each wantStats is the running
// total after that case's event. Every case does get a fresh
// OutputRoomEventConsumer, whose in-memory roomCounts start from a private
// copy of initRoomCounts (or empty when unset).
func TestMessageStats(t *testing.T) {
	type args struct {
		eventType   string
		eventSender string
		roomID      string
	}
	tests := []struct {
		name           string
		args           args
		ourServer      spec.ServerName
		lastUpdate     time.Time
		initRoomCounts map[spec.ServerName]map[string]bool
		wantStats      userAPITypes.MessageStats
	}{
		{
			name:      "m.room.create does not count as a message",
			ourServer: "localhost",
			args: args{
				eventType:   "m.room.create",
				eventSender: "@alice:localhost",
			},
		},
		{
			name:      "our server - message",
			ourServer: "localhost",
			args: args{
				eventType:   "m.room.message",
				eventSender: "@alice:localhost",
				roomID:      "normalRoom",
			},
			wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1},
		},
		{
			name:      "our server - E2EE message",
			ourServer: "localhost",
			args: args{
				eventType:   "m.room.encrypted",
				eventSender: "@alice:localhost",
				roomID:      "encryptedRoom",
			},
			wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1},
		},
		{
			name:      "remote server - message",
			ourServer: "localhost",
			args: args{
				eventType:   "m.room.message",
				eventSender: "@alice:remote",
				roomID:      "normalRoom",
			},
			wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1},
		},
		{
			name:      "remote server - E2EE message",
			ourServer: "localhost",
			args: args{
				eventType:   "m.room.encrypted",
				eventSender: "@alice:remote",
				roomID:      "encryptedRoom",
			},
			wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 2, SentMessagesE2EE: 1},
		},
		{
			name:      "day change creates a new room map",
			ourServer: "localhost",
			// Use 48h rather than 24h. On a 25-hour DST day, "now minus 24h" can
			// still fall on the same calendar day. 48h is always an earlier
			// calendar day and also at least 24h of elapsed time, whichever way
			// the consumer detects a day change.
			lastUpdate: time.Now().Add(-48 * time.Hour),
			initRoomCounts: map[spec.ServerName]map[string]bool{
				"localhost": {"encryptedRoom": true},
			},
			args: args{
				eventType:   "m.room.encrypted",
				eventSender: "@alice:remote",
				roomID:      "someOtherRoom",
			},
			wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 3, SentMessagesE2EE: 1},
		},
	}

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		db, closeDB := mustCreateDatabase(t, dbType)
		defer closeDB()

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				lastUpdate := tt.lastUpdate
				if lastUpdate.IsZero() {
					lastUpdate = time.Now()
				}

				// Deep-copy the fixture. The consumer may replace or extend the
				// per-server room maps, and this table is reused for every database
				// backend, so the shared fixture must not be handed out directly.
				roomCounts := make(map[spec.ServerName]map[string]bool, len(tt.initRoomCounts))
				for server, rooms := range tt.initRoomCounts {
					roomsCopy := make(map[string]bool, len(rooms))
					for roomID, seen := range rooms {
						roomsCopy[roomID] = seen
					}
					roomCounts[server] = roomsCopy
				}

				s := &OutputRoomEventConsumer{
					db:         db,
					msgCounts:  map[spec.ServerName]userAPITypes.MessageStats{},
					roomCounts: roomCounts,
					countsLock: sync.Mutex{},
					lastUpdate: lastUpdate,
					serverName: tt.ourServer,
				}
				s.storeMessageStats(context.Background(), tt.args.eventType, tt.args.eventSender, tt.args.roomID)
				t.Logf("%+v", s.roomCounts)

				gotStats, activeRooms, activeE2EERooms, err := db.DailyRoomsMessages(context.Background(), tt.ourServer)
				if err != nil {
					t.Fatalf("unexpected error: %s", err)
				}
				if !reflect.DeepEqual(gotStats, tt.wantStats) {
					t.Fatalf("expected %+v, got %+v", tt.wantStats, gotStats)
				}
				if tt.args.eventType == "m.room.encrypted" && activeE2EERooms != 1 {
					t.Fatalf("expected room to be activeE2EE")
				}
				if tt.args.eventType == "m.room.message" && activeRooms != 1 {
					t.Fatalf("expected room to be active")
				}
			})
		}
	})
}