From 3b869f5226eed2d3137b6289d746a7d8aff63318 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 28 Feb 2022 16:23:58 +0000 Subject: [PATCH] Initial working version of MSC2946 --- setup/mscs/msc2946/msc2946.go | 48 +++++++++++++++++++++++------------ setup/mscs/msc2946/storage.go | 24 +++++++++--------- 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3b7caea90..e3245c2e3 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -40,9 +40,10 @@ import ( ) const ( - ConstCreateEventContentKey = "type" - ConstSpaceChildEventType = "m.space.child" - ConstSpaceParentEventType = "m.space.parent" + ConstCreateEventContentKey = "type" + ConstCreateEventContentValueSpace = "m.space" + ConstSpaceChildEventType = "m.space.child" + ConstSpaceParentEventType = "m.space.parent" ) type MSC2946ClientResponse struct { @@ -230,7 +231,7 @@ func (w *walker) walk() util.JSONResponse { // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms // Depth first -> stack data structure - unvisited := []roomVisit{roomVisit{ + unvisited := []roomVisit{{ roomID: w.rootRoomID, depth: 0, }} @@ -251,6 +252,14 @@ func (w *walker) walk() util.JSONResponse { // Mark this room as processed. processed[rv.roomID] = true + // if this room is not a space room, skip. + var roomType string + create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + } + // Collect rooms/events to send back (either locally or fetched via federation) var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent @@ -266,12 +275,6 @@ func (w *walker) walk() util.JSONResponse { discoveredChildEvents = events pubRoom := w.publicRoomsChunk(rv.roomID) - roomType := "" - create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") - if create != nil { - // escape the `.`s so gjson doesn't think it's nested - roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str - } discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ PublicRoom: *pubRoom, @@ -299,6 +302,12 @@ func (w *walker) walk() util.JSONResponse { } } + // don't walk the children + // if the parent is not a space room + if roomType != ConstCreateEventContentValueSpace { + continue + } + uniqueRooms := make(set) for _, ev := range discoveredChildEvents { uniqueRooms[ev.StateKey] = true @@ -375,9 +384,9 @@ func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946Spa return nil, nil } // extract events which point to this room ID and extract their vias - events, err := w.db.References(w.ctx, roomID) + events, err := w.db.ChildReferences(w.ctx, roomID) if err != nil { - return nil, fmt.Errorf("failed to get References events: %w", err) + return nil, fmt.Errorf("failed to get ChildReferences events: %w", err) } vias := make(set) for _, ev := range events { @@ -516,15 +525,22 @@ func (w *walker) authorisedUser(roomID string) bool { // references returns all child references pointing to or from this room. func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { - events, err := w.db.References(w.ctx, roomID) + // don't return any child refs if the room is not a space room + create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "") + if create != nil { + // escape the `.`s so gjson doesn't think it's nested + roomType := gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str + if roomType != ConstCreateEventContentValueSpace { + return nil, nil + } + } + + events, err := w.db.ChildReferences(w.ctx, roomID) if err != nil { return nil, err } el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) for _, ev := range events { - if ev.Type() != ConstSpaceChildEventType { - continue - } // only return events that have a `via` key as per MSC1772 // else we'll incorrectly walk redacted events (as the link // is in the state_key) diff --git a/setup/mscs/msc2946/storage.go b/setup/mscs/msc2946/storage.go index 20db18594..ef7734039 100644 --- a/setup/mscs/msc2946/storage.go +++ b/setup/mscs/msc2946/storage.go @@ -34,15 +34,15 @@ var ( type Database interface { // StoreReference persists a child or parent space mapping. StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error - // References returns all events which have the given roomID as a parent or child space. - References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) + // ChildReferences returns all space child events in the given room. + ChildReferences(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) } type DB struct { - db *sql.DB - writer sqlutil.Writer - insertEdgeStmt *sql.Stmt - selectEdgesStmt *sql.Stmt + db *sql.DB + writer sqlutil.Writer + insertEdgeStmt *sql.Stmt + selectEdgesOfTypeStmt *sql.Stmt } // NewDatabase loads the database for msc2836 @@ -84,9 +84,9 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } - if d.selectEdgesStmt, err = d.db.Prepare(` + if d.selectEdgesOfTypeStmt, err = d.db.Prepare(` SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 + WHERE source_room_id = $1 AND rel_type = $2 `); err != nil { return nil, err } @@ -124,9 +124,9 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } - if d.selectEdgesStmt, err = d.db.Prepare(` + if d.selectEdgesOfTypeStmt, err = d.db.Prepare(` SELECT room_version, event_json FROM msc2946_edges - WHERE source_room_id = $1 OR dest_room_id = $2 + WHERE source_room_id = $1 AND rel_type = $2 `); err != nil { return nil, err } @@ -143,8 +143,8 @@ func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedE return err } -func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { - rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) +func (d *DB) ChildReferences(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { + rows, err := d.selectEdgesOfTypeStmt.QueryContext(ctx, roomID, relTypes[ConstSpaceChildEventType]) if err != nil { return nil, err }