diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 786d4f31f..5a67a1bed 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -1,12 +1,15 @@ package roomserver import ( + "bytes" "context" + "crypto/ed25519" "encoding/json" "fmt" "os" "reflect" "testing" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal/caching" @@ -80,7 +83,73 @@ func deleteDatabase() { } } -func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { +type fledglingEvent struct { + Type string + StateKey *string + Content interface{} + Sender string + RoomID string +} + +func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []gomatrixserverlib.HeaderedEvent) { + t.Helper() + depth := int64(1) + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + var prevs []string + roomState := make(map[gomatrixserverlib.StateKeyTuple]string) // state -> event ID + for _, ev := range events { + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Depth: depth, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + PrevEvents: prevs, + } + err := eb.SetContent(ev.Content) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) + } + stateNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&eb) + if err != nil { + t.Fatalf("mustCreateEvent: failed to work out auth_events : %s", err) + } + var authEvents []string + for _, tuple := range stateNeeded.Tuples() { + eventID := roomState[tuple] + if eventID != "" { + authEvents = append(authEvents, eventID) + } + } + eb.AuthEvents = authEvents + signedEvent, err := eb.Build(time.Now(), testOrigin, "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + depth++ + prevs = []string{signedEvent.EventID()} + if ev.StateKey != nil { + roomState[gomatrixserverlib.StateKeyTuple{ + EventType: ev.Type, + StateKey: *ev.StateKey, + }] = signedEvent.EventID() + } + result = append(result, signedEvent.Headered(roomVer)) + } + return +} + +func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage { + result := make([]json.RawMessage, len(events)) + for i := range events { + result[i] = events[i].JSON() + } + return result +} + +func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { + t.Helper() hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) for i := range events { e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver) @@ -93,7 +162,8 @@ func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js return hs } -func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { +func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyProducer) { + t.Helper() cfg := &config.Dendrite{} cfg.Defaults() cfg.Global.ServerName = testOrigin @@ -112,9 +182,14 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js Cfg: cfg, } - rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) - hevents := mustLoadEvents(t, ver, events) - if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { + return NewInternalAPI(base, &test.NopJSONVerifier{}), dp +} + +func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { + t.Helper() + rsAPI, dp := mustCreateRoomserverAPI(t) + hevents := mustLoadRawEvents(t, ver, events) + if err := api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents @@ -170,3 +245,163 @@ func TestOutputRedactedEvent(t *testing.T) { } } } + +// This tests that rewriting state via KindRewrite works correctly. +// This creates a small room with a create/join/name state, then replays it +// with a new room name. We expect the output events to contain the original events, +// followed by a single OutputNewRoomEvent with RewritesState set to true with the +// rewritten state events (with the 2nd room name). +func TestOutputRewritesState(t *testing.T) { + roomID := "!foo:" + string(testOrigin) + alice := "@alice:" + string(testOrigin) + emptyKey := "" + originalEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + "room_version": "6", + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world", + }, + StateKey: nil, + Type: "m.room.message", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + }) + rewriteEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name 2", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world 2", + }, + StateKey: nil, + Type: "m.room.message", + }, + }) + deleteDatabase() + rsAPI, producer := mustCreateRoomserverAPI(t) + defer deleteDatabase() + err := api.SendEvents(context.Background(), rsAPI, originalEvents, testOrigin, nil) + if err != nil { + t.Fatalf("failed to send original events: %s", err) + } + // assert we got them produced, this is just a sanity check and isn't the intention of this test + if len(producer.producedMessages) != len(originalEvents) { + t.Fatalf("SendEvents didn't result in same number of produced output events: got %d want %d", len(producer.producedMessages), len(originalEvents)) + } + producer.producedMessages = nil // we aren't actually interested in these events, just the rewrite ones + + var inputEvents []api.InputRoomEvent + // slowly build up the state IDs again, we're basically telling the roomserver what to store as a snapshot + var stateIDs []string + // skip the last event, we'll use this to tie together the rewrite as the KindNew event + for i := 0; i < len(rewriteEvents)-1; i++ { + ev := rewriteEvents[i] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindRewrite, + Event: ev, + AuthEventIDs: ev.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if ev.StateKey() != nil { + stateIDs = append(stateIDs, ev.EventID()) + } + } + lastEv := rewriteEvents[len(rewriteEvents)-1] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: lastEv, + AuthEventIDs: lastEv.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if err := api.SendInputRoomEvents(context.Background(), rsAPI, inputEvents); err != nil { + t.Fatalf("SendInputRoomEvents returned error for rewrite events: %s", err) + } + // we should just have one output event with the entire state of the room in it + if len(producer.producedMessages) != 1 { + t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages)) + } + outputEvent := producer.producedMessages[0] + if !outputEvent.NewRoomEvent.RewritesState { + t.Errorf("RewritesState flag not set on output event") + } + if !reflect.DeepEqual(stateIDs, outputEvent.NewRoomEvent.AddsStateEventIDs) { + t.Errorf("Output event is missing room state event IDs, got %v want %v", outputEvent.NewRoomEvent.AddsStateEventIDs, stateIDs) + } + if !bytes.Equal(outputEvent.NewRoomEvent.Event.JSON(), lastEv.JSON()) { + t.Errorf( + "Output event isn't the latest KindNew event:\ngot %s\nwant %s", + string(outputEvent.NewRoomEvent.Event.JSON()), + string(lastEv.JSON()), + ) + } + if len(outputEvent.NewRoomEvent.AddStateEvents) != len(stateIDs) { + t.Errorf("Output event is missing room state events themselves, got %d want %d", len(outputEvent.NewRoomEvent.AddStateEvents), len(stateIDs)) + } + // make sure the state got overwritten, check the room name + hasRoomName := false + for _, ev := range outputEvent.NewRoomEvent.AddStateEvents { + if ev.Type() == "m.room.name" { + hasRoomName = string(ev.Content()) == `{"name":"Room Name 2"}` + } + } + if !hasRoomName { + t.Errorf("Output event did not overwrite room state") + } +}