diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 4da9edf96..6d169573d 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -171,7 +171,7 @@ func (w *walker) walk() (*SpacesResponse, *util.JSONResponse) { roomID := unvisited[0] unvisited = unvisited[1:] // If this room has already been processed, skip. NB: do not remember this between calls - if processed[roomID] { + if processed[roomID] || roomID == "" { continue } // Mark this room as processed. @@ -219,6 +219,7 @@ func (w *walker) walk() (*SpacesResponse, *util.JSONResponse) { ev, gomatrixserverlib.FormatAll, )) uniqueRooms[ev.RoomID()] = true + uniqueRooms[SpaceTarget(ev)] = true w.markSent(ev.EventID()) } } @@ -227,10 +228,10 @@ func (w *walker) walk() (*SpacesResponse, *util.JSONResponse) { // are exceeded, stop adding events. If the event has already been added, do not add it again. numAdded := 0 for _, ev := range refs.events() { - if len(res.Events) >= w.req.Limit { + if w.req.Limit > 0 && len(res.Events) >= w.req.Limit { break } - if numAdded >= w.req.MaxRoomsPerSpace { + if w.req.MaxRoomsPerSpace > 0 && numAdded >= w.req.MaxRoomsPerSpace { break } if w.alreadySent(ev.EventID()) { @@ -240,6 +241,7 @@ func (w *walker) walk() (*SpacesResponse, *util.JSONResponse) { ev, gomatrixserverlib.FormatAll, )) uniqueRooms[ev.RoomID()] = true + uniqueRooms[SpaceTarget(ev)] = true w.markSent(ev.EventID()) // we don't distinguish between child state events and parent state events for the purposes of // max_rooms_per_space, maybe we should? @@ -331,9 +333,9 @@ func (el eventLookup) get(roomID, evType, stateKey string) *gomatrixserverlib.He func (el eventLookup) set(ev *gomatrixserverlib.HeaderedEvent) { evs := el[ev.Type()] if evs == nil { - evs = make([]*gomatrixserverlib.HeaderedEvent, 1) + evs = make([]*gomatrixserverlib.HeaderedEvent, 0) } - evs[0] = ev + evs = append(evs, ev) el[ev.Type()] = evs } diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 587dc5b42..1eaf7cf65 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -57,8 +57,10 @@ var ( // |_________ // | | | // R3 R4 S2 -// | +// | <-- this link is just a parent, not a child // R5 +// +// TODO: Alice is not joined to R4, but R4 is "world_readable". func TestMSC2946(t *testing.T) { alice := "@alice:localhost" // give access tokens to all three users @@ -77,6 +79,7 @@ func TestMSC2946(t *testing.T) { room2 := "!room2:localhost" room3 := "!room3:localhost" room4 := "!room4:localhost" + empty := "" room5 := "!room5:localhost" allRooms := []string{ rootSpace, subSpaceS1, subSpaceS2, @@ -142,16 +145,26 @@ func TestMSC2946(t *testing.T) { "present": true, }, }) + // This is a parent link only s2ToR5 := mustCreateEvent(t, fledglingEvent{ - RoomID: subSpaceS2, + RoomID: room5, Sender: alice, - Type: msc2946.ConstSpaceChildEventType, - StateKey: &room5, + Type: msc2946.ConstSpaceParentEventType, + StateKey: &empty, Content: map[string]interface{}{ + "room_id": subSpaceS2, "via": []string{"localhost"}, "present": true, }, }) + roomNameTuple := gomatrixserverlib.StateKeyTuple{ + EventType: "m.room.name", + StateKey: "", + } + hisVisTuple := gomatrixserverlib.StateKeyTuple{ + EventType: "m.room.history_visibility", + StateKey: "", + } nopRsAPI := &testRoomserverAPI{ userToJoinedRooms: map[string][]string{ alice: allRooms, @@ -165,6 +178,35 @@ func TestMSC2946(t *testing.T) { s1ToS2.EventID(): s1ToS2, s2ToR5.EventID(): s2ToR5, }, + pubRoomState: map[string]map[gomatrixserverlib.StateKeyTuple]string{ + rootSpace: { + roomNameTuple: "Root", + hisVisTuple: "shared", + }, + subSpaceS1: { + roomNameTuple: "Sub-Space 1", + hisVisTuple: "joined", + }, + subSpaceS2: { + roomNameTuple: "Sub-Space 2", + hisVisTuple: "shared", + }, + room1: { + hisVisTuple: "joined", + }, + room2: { + hisVisTuple: "joined", + }, + room3: { + hisVisTuple: "joined", + }, + room4: { + hisVisTuple: "world_readable", + }, + room5: { + hisVisTuple: "joined", + }, + }, } router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{ rootToR1, rootToR2, rootToS1, @@ -183,6 +225,16 @@ func TestMSC2946(t *testing.T) { t.Errorf("got %d rooms, want 0", len(res.Rooms)) } }) + t.Run("returns the entire graph", func(t *testing.T) { + res := postSpaces(t, 200, "alice", rootSpace, newReq(t, map[string]interface{}{})) + if len(res.Events) != 7 { + t.Errorf("got %d events, want 7", len(res.Events)) + } + if len(res.Rooms) != len(allRooms) { + t.Errorf("got %d rooms, want %d", len(res.Rooms), len(allRooms)) + } + + }) } func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2946.SpacesRequest { @@ -245,6 +297,7 @@ func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *m if err != nil { t.Fatalf("response 200 OK but failed to read response body: %s", err) } + t.Logf("Body: %s", string(body)) if err := json.Unmarshal(body, &result); err != nil { t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body)) } @@ -365,13 +418,37 @@ type testRoomserverAPI struct { roomserver.RoomserverInternalAPITrace userToJoinedRooms map[string][]string events map[string]*gomatrixserverlib.HeaderedEvent + pubRoomState map[string]map[gomatrixserverlib.StateKeyTuple]string } -func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { - for _, eventID := range req.EventIDs { - ev := r.events[eventID] - if ev != nil { - res.Events = append(res.Events, ev) +func (r *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *roomserver.QueryBulkStateContentRequest, res *roomserver.QueryBulkStateContentResponse) error { + res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for _, roomID := range req.RoomIDs { + pubRoomData, ok := r.pubRoomState[roomID] + if ok { + res.Rooms[roomID] = pubRoomData + } + } + return nil +} + +func (r *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *roomserver.QueryCurrentStateRequest, res *roomserver.QueryCurrentStateResponse) error { + res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + for _, he := range r.events { + if he.RoomID() != req.RoomID { + continue + } + if he.StateKey() == nil { + continue + } + tuple := gomatrixserverlib.StateKeyTuple{ + EventType: he.Type(), + StateKey: *he.StateKey(), + } + for _, t := range req.StateTuples { + if t == tuple { + res.StateEvents[t] = he + } } } return nil diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go index 69096e64b..a1b59c817 100644 --- a/setup/mscs/msc2946/storage.go +++ b/setup/mscs/msc2946/storage.go @@ -18,31 +18,170 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" +) + +var ( + relTypes = map[string]int{ + ConstSpaceChildEventType: 1, + ConstSpaceParentEventType: 2, + } + relTypesEnum = map[int]string{ + 1: ConstSpaceChildEventType, + 2: ConstSpaceParentEventType, + } ) type Database interface { + // StoreReference persists a child or parent space mapping. StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error + // References returns all events which have the given roomID as a parent or child space. References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) } type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt + db *sql.DB + writer sqlutil.Writer + insertEdgeStmt *sql.Stmt + selectEdgesStmt *sql.Stmt } // NewDatabase loads the database for msc2836 func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { - return &DB{}, nil + if dbOpts.ConnectionString.IsPostgres() { + return newPostgresDatabase(dbOpts) + } + return newSQLiteDatabase(dbOpts) +} + +func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewDummyWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2946_edges ( + room_version TEXT NOT NULL, + -- the room ID of the event, the source of the arrow + source_room_id TEXT NOT NULL, + -- the target room ID, the arrow destination + dest_room_id TEXT NOT NULL, + -- the kind of relation, either child or parent (1,2) + rel_type SMALLINT NOT NULL, + event_json TEXT NOT NULL, + CONSTRAINT msc2946_edges_uniq UNIQUE (source_room_id, dest_room_id, rel_type) + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + if d.selectEdgesStmt, err = d.db.Prepare(` + SELECT room_version, event_json FROM msc2946_edges + WHERE source_room_id = $1 OR dest_room_id = $2 + `); err != nil { + return nil, err + } + return &d, err +} + +func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2946_edges ( + room_version TEXT NOT NULL, + -- the room ID of the event, the source of the arrow + source_room_id TEXT NOT NULL, + -- the target room ID, the arrow destination + dest_room_id TEXT NOT NULL, + -- the kind of relation, either child or parent (1,2) + rel_type SMALLINT NOT NULL, + event_json TEXT NOT NULL, + UNIQUE (source_room_id, dest_room_id, rel_type) + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2946_edges(room_version, source_room_id, dest_room_id, rel_type, event_json) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + if d.selectEdgesStmt, err = d.db.Prepare(` + SELECT room_version, event_json FROM msc2946_edges + WHERE source_room_id = $1 OR dest_room_id = $2 + `); err != nil { + return nil, err + } + return &d, err } func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error { - return nil + target := SpaceTarget(he) + if target == "" { + return nil // malformed event + } + relType := relTypes[he.Type()] + _, err := d.insertEdgeStmt.ExecContext(ctx, he.RoomVersion, he.RoomID(), target, relType, he.JSON()) + return err } func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - return nil, nil + rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close References") + refs := make([]*gomatrixserverlib.HeaderedEvent, 0) + for rows.Next() { + var roomVer string + var jsonBytes []byte + if err := rows.Scan(&roomVer, &jsonBytes); err != nil { + return nil, err + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(jsonBytes, false, gomatrixserverlib.RoomVersion(roomVer)) + if err != nil { + return nil, err + } + he := ev.Headered(gomatrixserverlib.RoomVersion(roomVer)) + refs = append(refs, he) + } + return refs, nil +} + +// SpaceTarget returns the destination room ID for the space event. This is either a child or a parent +// depending on the event type. +func SpaceTarget(he *gomatrixserverlib.HeaderedEvent) string { + if he.StateKey() == nil { + return "" // no-op + } + switch he.Type() { + case ConstSpaceParentEventType: + return gjson.GetBytes(he.Content(), "room_id").Str + case ConstSpaceChildEventType: + return *he.StateKey() + } + return "" }