From b21dbde7875a29c1b83df770e53a7bf51318de8b Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 30 Oct 2020 18:36:28 +0000 Subject: [PATCH] Begin implementing core msc2836 logic --- internal/mscs/msc2836/msc2836.go | 143 +++++++++++++++++++++++++++++-- internal/mscs/msc2836/storage.go | 48 +++++++++-- 2 files changed, 181 insertions(+), 10 deletions(-) diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 1277a44df..57e2768dd 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "net/http" + "sort" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/hooks" @@ -35,14 +36,49 @@ type eventRelationshipRequest struct { MaxDepth int `json:"max_depth"` MaxBreadth int `json:"max_breadth"` Limit int `json:"limit"` - DepthFirst bool `json:"depth_first"` - RecentFirst bool `json:"recent_first"` - IncludeParent bool `json:"include_parent"` - IncludeChildren bool `json:"include_children"` + DepthFirst *bool `json:"depth_first"` + RecentFirst *bool `json:"recent_first"` + IncludeParent *bool `json:"include_parent"` + IncludeChildren *bool `json:"include_children"` Direction string `json:"direction"` Batch string `json:"batch"` } +func (r *eventRelationshipRequest) applyDefaults() { + if r.Limit > 100 || r.Limit < 1 { + r.Limit = 100 + } + if r.MaxBreadth == 0 { + r.MaxBreadth = 10 + } + if r.MaxDepth == 0 { + r.MaxDepth = 3 + } + t := true + f := false + if r.DepthFirst == nil { + r.DepthFirst = &f + } + if r.RecentFirst == nil { + r.RecentFirst = &t + } + if r.IncludeParent == nil { + r.IncludeParent = &f + } + if r.IncludeChildren == nil { + r.IncludeChildren = &f + } + if r.Direction != "up" { + r.Direction = "down" + } +} + +type eventRelationshipResponse struct { + Events []gomatrixserverlib.ClientEvent `json:"events"` + NextBatch string `json:"next_batch"` + Limited bool `json:"limited"` +} + // Enable this MSC func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { db, err := NewDatabase(&base.Cfg.MSCs.Database) @@ -70,11 +106,108 @@ func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } + // Sanity check request and set defaults. + relation.applyDefaults() + var res eventRelationshipResponse + var returnEvents []*gomatrixserverlib.HeaderedEvent + + // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. + event := getEventIfVisible(req.Context(), relation.EventID, device.UserID) + if event == nil { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + } + } + + // Retrieve the event. Add it to response array. + returnEvents = append(returnEvents, event) + + if *relation.IncludeParent { + if parentEvent := includeParent(req.Context(), event, device.UserID); parentEvent != nil { + returnEvents = append(returnEvents, parentEvent) + } + } + + if *relation.IncludeChildren { + remaining := relation.Limit - len(returnEvents) + if remaining > 0 { + children, resErr := includeChildren(req.Context(), db, event.EventID(), remaining, *relation.RecentFirst, device.UserID) + if resErr != nil { + return *resErr + } + returnEvents = append(returnEvents, children...) + } + } + + remaining := relation.Limit - len(returnEvents) + var walkLimited bool + if remaining > 0 { + // Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, + // honouring the limit, max_depth and max_breadth values according to the following rules + var events []*gomatrixserverlib.HeaderedEvent + events, walkLimited = walkThread(req.Context(), db, remaining) + returnEvents = append(returnEvents, events...) + } + res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) + for i, ev := range returnEvents { + res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll) + } + res.Limited = remaining == 0 || walkLimited + return util.JSONResponse{ Code: 200, - JSON: struct{}{}, + JSON: res, } }), ).Methods(http.MethodPost, http.MethodOptions) return nil } + +// If include_parent: true and there is a valid m.relationship field in the event, +// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. +func includeParent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, userID string) (parent *gomatrixserverlib.HeaderedEvent) { + parentID, _ := parentChildEventIDs(event) + if parentID == "" { + return nil + } + return getEventIfVisible(ctx, parentID, userID) +} + +// If include_children: true, lookup all events which have event_id as an m.relationship +// Apply history visibility checks to all these events and add the ones which pass into the response array, +// honouring the recent_first flag and the limit. +func includeChildren(ctx context.Context, db Database, parentID string, limit int, recentFirst bool, userID string) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { + children, err := db.ChildrenForParent(ctx, parentID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to get ChildrenForParent") + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var childEvents []*gomatrixserverlib.HeaderedEvent + for _, child := range children { + childEvent := getEventIfVisible(ctx, child, userID) + if childEvent != nil { + childEvents = append(childEvents, childEvent) + } + } + // sort childEvents by origin_server_ts in ASC or DESC depending on recent_first + sort.SliceStable(childEvents, func(i, j int) bool { + if recentFirst { + return childEvents[i].OriginServerTS().Time().After(childEvents[j].OriginServerTS().Time()) + } + return childEvents[i].OriginServerTS().Time().Before(childEvents[j].OriginServerTS().Time()) + }) + if len(childEvents) > limit { + return childEvents[:limit], nil + } + return childEvents, nil +} + +func walkThread(ctx context.Context, db Database, limit int) ([]*gomatrixserverlib.HeaderedEvent, bool) { + return nil, false +} + +func getEventIfVisible(ctx context.Context, eventID, userID string) *gomatrixserverlib.HeaderedEvent { + return nil +} diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go index 7401d7fbc..358a7c980 100644 --- a/internal/mscs/msc2836/storage.go +++ b/internal/mscs/msc2836/storage.go @@ -13,11 +13,13 @@ import ( type Database interface { // StoreRelation stores the parent->child and child->parent relationship for later querying. StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + ChildrenForParent(ctx context.Context, eventID string) ([]string, error) } type Postgres struct { - db *sql.DB - insertRelationStmt *sql.Stmt + db *sql.DB + insertRelationStmt *sql.Stmt + selectChildrenForParentStmt *sql.Stmt } func NewPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { @@ -39,6 +41,11 @@ func NewPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } + if p.selectChildrenForParentStmt, err = p.db.Prepare(` + SELECT child_event_id FROM msc2836_relationships WHERE parent_event_id = $1 + `); err != nil { + return nil, err + } return &p, err } @@ -51,10 +58,15 @@ func (p *Postgres) StoreRelation(ctx context.Context, ev *gomatrixserverlib.Head return err } +func (p *Postgres) ChildrenForParent(ctx context.Context, eventID string) ([]string, error) { + return childrenForParent(ctx, eventID, p.selectChildrenForParentStmt) +} + type SQLite struct { - db *sql.DB - insertRelationStmt *sql.Stmt - writer sqlutil.Writer + db *sql.DB + insertRelationStmt *sql.Stmt + selectChildrenForParentStmt *sql.Stmt + writer sqlutil.Writer } func NewSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { @@ -77,6 +89,11 @@ func NewSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { `); err != nil { return nil, err } + if s.selectChildrenForParentStmt, err = s.db.Prepare(` + SELECT child_event_id FROM msc2836_relationships WHERE parent_event_id = $1 + `); err != nil { + return nil, err + } return &s, nil } @@ -89,6 +106,10 @@ func (s *SQLite) StoreRelation(ctx context.Context, ev *gomatrixserverlib.Header return err } +func (s *SQLite) ChildrenForParent(ctx context.Context, eventID string) ([]string, error) { + return childrenForParent(ctx, eventID, s.selectChildrenForParentStmt) +} + // NewDatabase loads the database for msc2836 func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if dbOpts.ConnectionString.IsPostgres() { @@ -115,3 +136,20 @@ func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent string, ch } return } + +func childrenForParent(ctx context.Context, eventID string, stmt *sql.Stmt) ([]string, error) { + rows, err := stmt.QueryContext(ctx, eventID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var children []string + for rows.Next() { + var childID string + if err := rows.Scan(&childID); err != nil { + return nil, err + } + children = append(children, childID) + } + return children, nil +}