Initial working version of MSC2946

This commit is contained in:
Kegan Dougal 2022-02-28 16:23:58 +00:00
parent c887ba31ea
commit 3b869f5226
2 changed files with 44 additions and 28 deletions

View file

@ -40,9 +40,10 @@ import (
) )
const ( const (
ConstCreateEventContentKey = "type" ConstCreateEventContentKey = "type"
ConstSpaceChildEventType = "m.space.child" ConstCreateEventContentValueSpace = "m.space"
ConstSpaceParentEventType = "m.space.parent" ConstSpaceChildEventType = "m.space.child"
ConstSpaceParentEventType = "m.space.parent"
) )
type MSC2946ClientResponse struct { type MSC2946ClientResponse struct {
@ -230,7 +231,7 @@ func (w *walker) walk() util.JSONResponse {
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms
// Depth first -> stack data structure // Depth first -> stack data structure
unvisited := []roomVisit{roomVisit{ unvisited := []roomVisit{{
roomID: w.rootRoomID, roomID: w.rootRoomID,
depth: 0, depth: 0,
}} }}
@ -251,6 +252,14 @@ func (w *walker) walk() util.JSONResponse {
// Mark this room as processed. // Mark this room as processed.
processed[rv.roomID] = true processed[rv.roomID] = true
// if this room is not a space room, skip.
var roomType string
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
// Collect rooms/events to send back (either locally or fetched via federation) // Collect rooms/events to send back (either locally or fetched via federation)
var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent var discoveredChildEvents []gomatrixserverlib.MSC2946StrippedEvent
@ -266,12 +275,6 @@ func (w *walker) walk() util.JSONResponse {
discoveredChildEvents = events discoveredChildEvents = events
pubRoom := w.publicRoomsChunk(rv.roomID) pubRoom := w.publicRoomsChunk(rv.roomID)
roomType := ""
create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
}
discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{ discoveredRooms = append(discoveredRooms, gomatrixserverlib.MSC2946Room{
PublicRoom: *pubRoom, PublicRoom: *pubRoom,
@ -299,6 +302,12 @@ func (w *walker) walk() util.JSONResponse {
} }
} }
// don't walk the children
// if the parent is not a space room
if roomType != ConstCreateEventContentValueSpace {
continue
}
uniqueRooms := make(set) uniqueRooms := make(set)
for _, ev := range discoveredChildEvents { for _, ev := range discoveredChildEvents {
uniqueRooms[ev.StateKey] = true uniqueRooms[ev.StateKey] = true
@ -375,9 +384,9 @@ func (w *walker) federatedRoomInfo(roomID string) (*gomatrixserverlib.MSC2946Spa
return nil, nil return nil, nil
} }
// extract events which point to this room ID and extract their vias // extract events which point to this room ID and extract their vias
events, err := w.db.References(w.ctx, roomID) events, err := w.db.ChildReferences(w.ctx, roomID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get References events: %w", err) return nil, fmt.Errorf("failed to get ChildReferences events: %w", err)
} }
vias := make(set) vias := make(set)
for _, ev := range events { for _, ev := range events {
@ -516,15 +525,22 @@ func (w *walker) authorisedUser(roomID string) bool {
// references returns all child references pointing to or from this room. // references returns all child references pointing to or from this room.
func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) { func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946StrippedEvent, error) {
events, err := w.db.References(w.ctx, roomID) // don't return any child refs if the room is not a space room
create := w.stateEvent(roomID, gomatrixserverlib.MRoomCreate, "")
if create != nil {
// escape the `.`s so gjson doesn't think it's nested
roomType := gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str
if roomType != ConstCreateEventContentValueSpace {
return nil, nil
}
}
events, err := w.db.ChildReferences(w.ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events)) el := make([]gomatrixserverlib.MSC2946StrippedEvent, 0, len(events))
for _, ev := range events { for _, ev := range events {
if ev.Type() != ConstSpaceChildEventType {
continue
}
// only return events that have a `via` key as per MSC1772 // only return events that have a `via` key as per MSC1772
// else we'll incorrectly walk redacted events (as the link // else we'll incorrectly walk redacted events (as the link
// is in the state_key) // is in the state_key)

View file

@ -34,15 +34,15 @@ var (
type Database interface { type Database interface {
// StoreReference persists a child or parent space mapping. // StoreReference persists a child or parent space mapping.
StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedEvent) error
// References returns all events which have the given roomID as a parent or child space. // ChildReferences returns all space child events in the given room.
References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) ChildReferences(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error)
} }
type DB struct { type DB struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
insertEdgeStmt *sql.Stmt insertEdgeStmt *sql.Stmt
selectEdgesStmt *sql.Stmt selectEdgesOfTypeStmt *sql.Stmt
} }
// NewDatabase loads the database for msc2836 // NewDatabase loads the database for msc2836
@ -84,9 +84,9 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.selectEdgesStmt, err = d.db.Prepare(` if d.selectEdgesOfTypeStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2 WHERE source_room_id = $1 AND rel_type = $2
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
@ -124,9 +124,9 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
if d.selectEdgesStmt, err = d.db.Prepare(` if d.selectEdgesOfTypeStmt, err = d.db.Prepare(`
SELECT room_version, event_json FROM msc2946_edges SELECT room_version, event_json FROM msc2946_edges
WHERE source_room_id = $1 OR dest_room_id = $2 WHERE source_room_id = $1 AND rel_type = $2
`); err != nil { `); err != nil {
return nil, err return nil, err
} }
@ -143,8 +143,8 @@ func (d *DB) StoreReference(ctx context.Context, he *gomatrixserverlib.HeaderedE
return err return err
} }
func (d *DB) References(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *DB) ChildReferences(ctx context.Context, roomID string) ([]*gomatrixserverlib.HeaderedEvent, error) {
rows, err := d.selectEdgesStmt.QueryContext(ctx, roomID, roomID) rows, err := d.selectEdgesOfTypeStmt.QueryContext(ctx, roomID, relTypes[ConstSpaceChildEventType])
if err != nil { if err != nil {
return nil, err return nil, err
} }