implement add_state_ids

This commit is contained in:
Kegan Dougal 2022-05-09 09:23:29 +01:00
parent 6fdc3b0e4c
commit 71e400b8cf

View file

@ -2,6 +2,8 @@ package syncapi
import ( import (
"context" "context"
"encoding/json"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -11,13 +13,48 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
type syncRoomserverAPI struct { type syncRoomserverAPI struct {
rsapi.RoomserverInternalAPI rsapi.SyncRoomserverAPI
rooms []*test.Room
}
func (s *syncRoomserverAPI) QueryEventsByID(ctx context.Context, req *rsapi.QueryEventsByIDRequest, res *rsapi.QueryEventsByIDResponse) error {
NextEvent:
for _, eventID := range req.EventIDs {
for _, r := range s.rooms {
for _, ev := range r.Events() {
fmt.Println(ev.EventID())
if ev.EventID() == eventID {
res.Events = append(res.Events, ev)
continue NextEvent
}
}
}
}
fmt.Println("QueryEventsByID", req.EventIDs, " returning ", len(res.Events))
return nil
}
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room
for _, r := range s.rooms {
if r.ID == req.RoomID {
room = r
break
}
}
if room == nil {
res.RoomExists = false
return nil
}
res.RoomVersion = room.Version
return nil // TODO: return state
} }
type syncUserAPI struct { type syncUserAPI struct {
@ -45,9 +82,11 @@ type syncKeyAPI struct {
} }
func TestSyncAPI(t *testing.T) { func TestSyncAPI(t *testing.T) {
testSync(t, test.DBTypePostgres)
/*
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testSync(t, dbType) testSync(t, dbType)
}) }) */
} }
func testSync(t *testing.T, dbType test.DBType) { func testSync(t *testing.T, dbType test.DBType) {
@ -68,16 +107,21 @@ func testSync(t *testing.T, dbType test.DBType) {
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
var msgs []*nats.Msg var msgs []*nats.Msg
for _, ev := range room.Events() { for _, ev := range room.Events() {
var addsStateIDs []string
if ev.StateKey() != nil {
addsStateIDs = append(addsStateIDs, ev.EventID())
}
msgs = append(msgs, test.NewOutputEventMsg(t, base, room.ID, api.OutputEvent{ msgs = append(msgs, test.NewOutputEventMsg(t, base, room.ID, api.OutputEvent{
Type: rsapi.OutputTypeNewRoomEvent, Type: rsapi.OutputTypeNewRoomEvent,
NewRoomEvent: &rsapi.OutputNewRoomEvent{ NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev, Event: ev,
AddsStateEventIDs: addsStateIDs,
}, },
})) }))
} }
test.MustPublishMsgs(t, jsctx, msgs...) test.MustPublishMsgs(t, jsctx, msgs...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testCases := []struct { testCases := []struct {
name string name string
@ -111,7 +155,7 @@ func testSync(t *testing.T, dbType test.DBType) {
}, },
} }
// TODO: find a better way // TODO: find a better way
time.Sleep(100 * time.Millisecond) time.Sleep(1000 * time.Millisecond)
for _, tc := range testCases { for _, tc := range testCases {
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -119,7 +163,6 @@ func testSync(t *testing.T, dbType test.DBType) {
if w.Code != tc.wantCode { if w.Code != tc.wantCode {
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
} }
/*
if tc.wantJoinedRooms != nil { if tc.wantJoinedRooms != nil {
var res types.Response var res types.Response
if err := json.NewDecoder(w.Body).Decode(&res); err != nil { if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
@ -128,6 +171,6 @@ func testSync(t *testing.T, dbType test.DBType) {
if len(res.Rooms.Join) != len(tc.wantJoinedRooms) { if len(res.Rooms.Join) != len(tc.wantJoinedRooms) {
t.Errorf("%s: got %v joined rooms, want %v.\nResponse: %+v", tc.name, len(res.Rooms.Join), len(tc.wantJoinedRooms), res) t.Errorf("%s: got %v joined rooms, want %v.\nResponse: %+v", tc.name, len(res.Rooms.Join), len(tc.wantJoinedRooms), res)
} }
} */ }
} }
} }