diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index c5e281aae..202c185f1 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -29,6 +29,58 @@ import ( const testSenderID = "testSenderID" const testUserID = "@test:localhost" +type EventFieldsToVerify struct { + EventID string + Type string + OriginServerTS spec.Timestamp + StateKey *string + Content spec.RawJSON + Unsigned spec.RawJSON + Sender string + Depth int64 + PrevEvents []string + AuthEvents []string +} + +func verifyEventFields(t *testing.T, got EventFieldsToVerify, want EventFieldsToVerify) { + if got.EventID != want.EventID { + t.Errorf("ClientEvent.EventID: wanted %s, got %s", want.EventID, got.EventID) + } + if got.OriginServerTS != want.OriginServerTS { + t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", want.OriginServerTS, got.OriginServerTS) + } + if got.StateKey == nil && want.StateKey != nil { + t.Errorf("ClientEvent.StateKey: no state key present when one was wanted: %s", *want.StateKey) + } + if got.StateKey != nil && want.StateKey == nil { + t.Errorf("ClientEvent.StateKey: state key present when one was not wanted: %s", *got.StateKey) + } + if got.StateKey != nil && want.StateKey != nil && *got.StateKey != *want.StateKey { + t.Errorf("ClientEvent.StateKey: wanted %s, got %s", *want.StateKey, *got.StateKey) + } + if got.Type != want.Type { + t.Errorf("ClientEvent.Type: wanted %s, got %s", want.Type, got.Type) + } + if !bytes.Equal(got.Content, want.Content) { + t.Errorf("ClientEvent.Content: wanted %s, got %s", string(want.Content), string(got.Content)) + } + if !bytes.Equal(got.Unsigned, want.Unsigned) { + t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(want.Unsigned), string(got.Unsigned)) + } + if got.Sender != want.Sender { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", want.Sender, got.Sender) + } + if got.Depth != want.Depth { + t.Errorf("ClientEvent.Depth: wanted %d, got %d", want.Depth, got.Depth) + } + if !reflect.DeepEqual(got.PrevEvents, want.PrevEvents) { + t.Errorf("ClientEvent.PrevEvents: wanted %v, got %v", want.PrevEvents, got.PrevEvents) + } + if !reflect.DeepEqual(got.AuthEvents, want.AuthEvents) { + t.Errorf("ClientEvent.AuthEvents: wanted %v, got %v", want.AuthEvents, got.AuthEvents) + } +} + func TestToClientEvent(t *testing.T) { // nolint: gocyclo ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", @@ -55,27 +107,27 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo } sk := "" ce := ToClientEvent(ev, FormatAll, userID.String(), &sk, ev.Unsigned()) - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != "" { - t.Errorf("ClientEvent.StateKey: wanted '', got %v", ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != userID.String() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) - } + + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: userID.String(), + }) + j, err := json.Marshal(ce) if err != nil { t.Fatalf("failed to Marshal ClientEvent: %s", err) @@ -155,36 +207,32 @@ func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo } sk := "" ce := ToClientEvent(ev, FormatSyncFederation, userID.String(), &sk, ev.Unsigned()) - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != "" { - t.Errorf("ClientEvent.StateKey: wanted '', got %v", ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != userID.String() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) - } - if ce.Depth != ev.Depth() { - t.Errorf("ClientEvent.Depth: wanted %d, got %d", ev.Depth(), ce.Depth) - } - if !reflect.DeepEqual(ce.PrevEvents, ev.PrevEventIDs()) { - t.Errorf("ClientEvent.PrevEvents: wanted %v, got %v", ev.PrevEventIDs(), ce.PrevEvents) - } - if !reflect.DeepEqual(ce.AuthEvents, ev.AuthEventIDs()) { - t.Errorf("ClientEvent.AuthEvents: wanted %v, got %v", ev.AuthEventIDs(), ce.AuthEvents) - } + + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + Depth: ce.Depth, + PrevEvents: ce.PrevEvents, + AuthEvents: ce.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: userID.String(), + Depth: ev.Depth(), + PrevEvents: ev.PrevEventIDs(), + AuthEvents: ev.AuthEventIDs(), + }) } func userIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { @@ -255,68 +303,59 @@ func TestToClientEventsFormatSyncFederation(t *testing.T) { // nolint: gocyclo clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSyncFederation, userIDForSender) ce := clientEvents[0] - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != testSenderID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testSenderID, ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != testSenderID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testSenderID, ce.Sender) - } - if ce.Depth != ev.Depth() { - t.Errorf("ClientEvent.Depth: wanted %d, got %d", ev.Depth(), ce.Depth) - } - if !reflect.DeepEqual(ce.PrevEvents, ev.PrevEventIDs()) { - t.Errorf("ClientEvent.PrevEvents: wanted %v, got %v", ev.PrevEventIDs(), ce.PrevEvents) - } - if !reflect.DeepEqual(ce.AuthEvents, ev.AuthEventIDs()) { - t.Errorf("ClientEvent.AuthEvents: wanted %v, got %v", ev.AuthEventIDs(), ce.AuthEvents) - } + sk := testSenderID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + Depth: ce.Depth, + PrevEvents: ce.PrevEvents, + AuthEvents: ce.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testSenderID, + Depth: ev.Depth(), + PrevEvents: ev.PrevEventIDs(), + AuthEvents: ev.AuthEventIDs(), + }) ce2 := clientEvents[1] - if ce2.EventID != ev2.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev2.EventID(), ce2.EventID) - } - if ce2.OriginServerTS != ev2.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev2.OriginServerTS(), ce2.OriginServerTS) - } - if ce2.StateKey == nil || *ce.StateKey != testSenderID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testSenderID, ce2.StateKey) - } - if ce2.Type != ev2.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev2.Type(), ce2.Type) - } - if !bytes.Equal(ce2.Content, ev2.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev2.Content()), string(ce2.Content)) - } - if !bytes.Equal(ce2.Unsigned, ev2.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev2.Unsigned()), string(ce2.Unsigned)) - } - if ce2.Sender != testSenderID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testSenderID, ce2.Sender) - } - if ce2.Depth != ev2.Depth() { - t.Errorf("ClientEvent.Depth: wanted %d, got %d", ev2.Depth(), ce2.Depth) - } - if !reflect.DeepEqual(ce2.PrevEvents, ev2.PrevEventIDs()) { - t.Errorf("ClientEvent.PrevEvents: wanted %v, got %v", ev2.PrevEventIDs(), ce2.PrevEvents) - } - if !reflect.DeepEqual(ce2.AuthEvents, ev2.AuthEventIDs()) { - t.Errorf("ClientEvent.AuthEvents: wanted %v, got %v", ev2.AuthEventIDs(), ce2.AuthEvents) - } + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + Depth: ce2.Depth, + PrevEvents: ce2.PrevEvents, + AuthEvents: ce2.AuthEvents, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: ev2.Unsigned(), + Sender: testSenderID, + Depth: ev2.Depth(), + PrevEvents: ev2.PrevEventIDs(), + AuthEvents: ev2.AuthEventIDs(), + }) } func TestToClientEventsFormatSync(t *testing.T) { // nolint: gocyclo @@ -363,27 +402,26 @@ func TestToClientEventsFormatSync(t *testing.T) { // nolint: gocyclo clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSync, userIDForSender) ce := clientEvents[0] - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != testUserID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testUserID, ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != testUserID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testUserID, ce.Sender) - } + sk := testUserID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testUserID, + }) var prev PrevEventRef prev.PrevContent = []byte(`{"name": "Goodbye World 2"}`) @@ -391,27 +429,25 @@ func TestToClientEventsFormatSync(t *testing.T) { // nolint: gocyclo expectedUnsigned, _ := json.Marshal(prev) ce2 := clientEvents[1] - if ce2.EventID != ev2.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev2.EventID(), ce2.EventID) - } - if ce2.OriginServerTS != ev2.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev2.OriginServerTS(), ce2.OriginServerTS) - } - if ce2.StateKey == nil || *ce.StateKey != testUserID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testUserID, ce2.StateKey) - } - if ce2.Type != ev2.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev2.Type(), ce2.Type) - } - if !bytes.Equal(ce2.Content, ev2.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev2.Content()), string(ce2.Content)) - } - if !bytes.Equal(ce2.Unsigned, expectedUnsigned) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(expectedUnsigned), string(ce2.Unsigned)) - } - if ce2.Sender != testUserID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testUserID, ce2.Sender) - } + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: expectedUnsigned, + Sender: testUserID, + }) } func TestToClientEventsFormatSyncUnknownPrevSender(t *testing.T) { // nolint: gocyclo @@ -458,27 +494,26 @@ func TestToClientEventsFormatSyncUnknownPrevSender(t *testing.T) { // nolint: go clientEvents := ToClientEvents([]gomatrixserverlib.PDU{ev, ev2}, FormatSync, userIDForSender) ce := clientEvents[0] - if ce.EventID != ev.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) - } - if ce.OriginServerTS != ev.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev.OriginServerTS(), ce.OriginServerTS) - } - if ce.StateKey == nil || *ce.StateKey != testUserID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testUserID, ce.StateKey) - } - if ce.Type != ev.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev.Type(), ce.Type) - } - if !bytes.Equal(ce.Content, ev.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev.Content()), string(ce.Content)) - } - if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) - } - if ce.Sender != testUserID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testUserID, ce.Sender) - } + sk := testUserID + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce.EventID, + Type: ce.Type, + OriginServerTS: ce.OriginServerTS, + StateKey: ce.StateKey, + Content: ce.Content, + Unsigned: ce.Unsigned, + Sender: ce.Sender, + }, + EventFieldsToVerify{ + EventID: ev.EventID(), + Type: ev.Type(), + OriginServerTS: ev.OriginServerTS(), + StateKey: &sk, + Content: ev.Content(), + Unsigned: ev.Unsigned(), + Sender: testUserID, + }) var prev PrevEventRef prev.PrevContent = []byte(`{"name": "Goodbye World 2"}`) @@ -486,25 +521,23 @@ func TestToClientEventsFormatSyncUnknownPrevSender(t *testing.T) { // nolint: go expectedUnsigned, _ := json.Marshal(prev) ce2 := clientEvents[1] - if ce2.EventID != ev2.EventID() { - t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev2.EventID(), ce2.EventID) - } - if ce2.OriginServerTS != ev2.OriginServerTS() { - t.Errorf("ClientEvent.OriginServerTS: wanted %d, got %d", ev2.OriginServerTS(), ce2.OriginServerTS) - } - if ce2.StateKey == nil || *ce.StateKey != testUserID { - t.Errorf("ClientEvent.StateKey: wanted %s, got %v", testUserID, ce2.StateKey) - } - if ce2.Type != ev2.Type() { - t.Errorf("ClientEvent.Type: wanted %s, got %s", ev2.Type(), ce2.Type) - } - if !bytes.Equal(ce2.Content, ev2.Content()) { - t.Errorf("ClientEvent.Content: wanted %s, got %s", string(ev2.Content()), string(ce2.Content)) - } - if !bytes.Equal(ce2.Unsigned, expectedUnsigned) { - t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(expectedUnsigned), string(ce2.Unsigned)) - } - if ce2.Sender != testUserID { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", testUserID, ce2.Sender) - } + verifyEventFields(t, + EventFieldsToVerify{ + EventID: ce2.EventID, + Type: ce2.Type, + OriginServerTS: ce2.OriginServerTS, + StateKey: ce2.StateKey, + Content: ce2.Content, + Unsigned: ce2.Unsigned, + Sender: ce2.Sender, + }, + EventFieldsToVerify{ + EventID: ev2.EventID(), + Type: ev2.Type(), + OriginServerTS: ev2.OriginServerTS(), + StateKey: &sk, + Content: ev2.Content(), + Unsigned: expectedUnsigned, + Sender: testUserID, + }) }