diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 7185bac94..732bca423 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -2,6 +2,8 @@ package syncapi import ( "context" + "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -11,13 +13,48 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/nats-io/nats.go" ) type syncRoomserverAPI struct { - rsapi.RoomserverInternalAPI + rsapi.SyncRoomserverAPI + rooms []*test.Room +} + +func (s *syncRoomserverAPI) QueryEventsByID(ctx context.Context, req *rsapi.QueryEventsByIDRequest, res *rsapi.QueryEventsByIDResponse) error { +NextEvent: + for _, eventID := range req.EventIDs { + for _, r := range s.rooms { + for _, ev := range r.Events() { + fmt.Println(ev.EventID()) + if ev.EventID() == eventID { + res.Events = append(res.Events, ev) + continue NextEvent + } + } + } + } + fmt.Println("QueryEventsByID", req.EventIDs, " returning ", len(res.Events)) + return nil +} + +func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { + var room *test.Room + for _, r := range s.rooms { + if r.ID == req.RoomID { + room = r + break + } + } + if room == nil { + res.RoomExists = false + return nil + } + res.RoomVersion = room.Version + return nil // TODO: return state } type syncUserAPI struct { @@ -45,9 +82,11 @@ type syncKeyAPI struct { } func TestSyncAPI(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - testSync(t, dbType) - }) + testSync(t, test.DBTypePostgres) + /* + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testSync(t, dbType) + }) */ } func testSync(t *testing.T, dbType test.DBType) { @@ -68,16 +107,21 @@ func testSync(t *testing.T, dbType test.DBType) { defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) var msgs []*nats.Msg for _, ev := range room.Events() { + var addsStateIDs []string + if ev.StateKey() != nil { + addsStateIDs = append(addsStateIDs, ev.EventID()) + } msgs = append(msgs, test.NewOutputEventMsg(t, base, room.ID, api.OutputEvent{ Type: rsapi.OutputTypeNewRoomEvent, NewRoomEvent: &rsapi.OutputNewRoomEvent{ - Event: ev, + Event: ev, + AddsStateEventIDs: addsStateIDs, }, })) } test.MustPublishMsgs(t, jsctx, msgs...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) testCases := []struct { name string @@ -111,7 +155,7 @@ func testSync(t *testing.T, dbType test.DBType) { }, } // TODO: find a better way - time.Sleep(100 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) for _, tc := range testCases { w := httptest.NewRecorder() @@ -119,15 +163,14 @@ func testSync(t *testing.T, dbType test.DBType) { if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } - /* - if tc.wantJoinedRooms != nil { - var res types.Response - if err := json.NewDecoder(w.Body).Decode(&res); err != nil { - t.Fatalf("%s: failed to decode response body: %s", tc.name, err) - } - if len(res.Rooms.Join) != len(tc.wantJoinedRooms) { - t.Errorf("%s: got %v joined rooms, want %v.\nResponse: %+v", tc.name, len(res.Rooms.Join), len(tc.wantJoinedRooms), res) - } - } */ + if tc.wantJoinedRooms != nil { + var res types.Response + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Fatalf("%s: failed to decode response body: %s", tc.name, err) + } + if len(res.Rooms.Join) != len(tc.wantJoinedRooms) { + t.Errorf("%s: got %v joined rooms, want %v.\nResponse: %+v", tc.name, len(res.Rooms.Join), len(tc.wantJoinedRooms), res) + } + } } }