diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index db2334240..98e3ff36c 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -501,6 +501,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MS // lookForEvent returns the event for the event ID given, by trying to auto-join rooms if not authorised and by querying remote servers // if the event ID is unknown. If `exploreThread` is true, remote requests will use /event_relationships instead of /event. This is // desirable when walking the thread, but is not desirable when satisfying include_parent|children flags. +// nolint:gocyclo func (rc *reqCtx) lookForEvent(eventID string, exploreThread bool) *gomatrixserverlib.HeaderedEvent { event := rc.getLocalEvent(eventID) if event == nil { @@ -641,6 +642,9 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { count, hash := rc.getChildMetadata(ev.EventID()) + if count == 0 { + return + } err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash)) if err != nil { util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash") diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go index 9bec02b4c..ab253b6e7 100644 --- a/internal/mscs/msc2836/storage.go +++ b/internal/mscs/msc2836/storage.go @@ -1,14 +1,15 @@ package msc2836 import ( + "bytes" "context" "database/sql" + "encoding/base64" "encoding/json" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type eventInfo struct { @@ -50,6 +51,9 @@ type DB struct { selectChildrenForParentOldestFirstStmt *sql.Stmt selectChildrenForParentRecentFirstStmt *sql.Stmt selectParentForChildStmt *sql.Stmt + updateChildMetadataStmt *sql.Stmt + selectChildMetadataStmt *sql.Stmt + updateChildMetadataExploredStmt *sql.Stmt } // NewDatabase loads the database for msc2836 @@ -81,19 +85,26 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { CREATE TABLE IF NOT EXISTS msc2836_nodes ( event_id TEXT PRIMARY KEY NOT NULL, origin_server_ts BIGINT NOT NULL, - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + unsigned_children_count BIGINT NOT NULL, + unsigned_children_hash TEXT NOT NULL, + explored SMALLINT NOT NULL ); `) if err != nil { return nil, err } if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } if d.insertNodeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored) + VALUES($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } @@ -110,7 +121,23 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return nil, err } if d.selectParentForChildStmt, err = d.db.Prepare(` - SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2 + SELECT parent_event_id, parent_room_id FROM msc2836_edges + WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } + if d.updateChildMetadataStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4 + `); err != nil { + return nil, err + } + if d.selectChildMetadataStmt, err = d.db.Prepare(` + SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1 + `); err != nil { + return nil, err + } + if d.updateChildMetadataExploredStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2 `); err != nil { return nil, err } @@ -138,19 +165,26 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { CREATE TABLE IF NOT EXISTS msc2836_nodes ( event_id TEXT PRIMARY KEY NOT NULL, origin_server_ts BIGINT NOT NULL, - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + unsigned_children_count BIGINT NOT NULL, + unsigned_children_hash TEXT NOT NULL, + explored SMALLINT NOT NULL ); `) if err != nil { return nil, err } if d.insertEdgeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING `); err != nil { return nil, err } if d.insertNodeStmt, err = d.db.Prepare(` - INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored) + VALUES($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING `); err != nil { return nil, err } @@ -167,7 +201,23 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return nil, err } if d.selectParentForChildStmt, err = d.db.Prepare(` - SELECT parent_event_id, parent_room_id FROM msc2836_edges WHERE child_event_id = $1 AND rel_type = $2 + SELECT parent_event_id, parent_room_id FROM msc2836_edges + WHERE child_event_id = $1 AND rel_type = $2 + `); err != nil { + return nil, err + } + if d.updateChildMetadataStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4 + `); err != nil { + return nil, err + } + if d.selectChildMetadataStmt, err = d.db.Prepare(` + SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1 + `); err != nil { + return nil, err + } + if d.updateChildMetadataExploredStmt, err = d.db.Prepare(` + UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2 `); err != nil { return nil, err } @@ -184,16 +234,50 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv if err != nil { return err } + count, hash := extractChildMetadata(ev) return p.writer.Do(p.db, nil, func(txn *sql.Tx) error { _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON)) if err != nil { return err } - _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID()) + _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0) return err }) } +func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + // extract current children count/hash, if they are less than the current event then update the columns and set to unexplored + count, hash, _, err := p.ChildMetadata(ctx, ev.EventID()) + if err != nil { + return err + } + eventCount, eventHash := extractChildMetadata(ev) + if eventCount == 0 { + return nil // nothing to update with + } + if eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash)) { + _, err = p.updateChildMetadataStmt.ExecContext(ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, ev.EventID()) + return err + } + return nil +} + +func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) { + var b64hash string + var exploredInt int + if err = p.selectChildMetadataStmt.QueryRowContext(ctx, eventID).Scan(&count, &b64hash, &exploredInt); err != nil { + return + } + hash, err = base64.RawStdEncoding.DecodeString(b64hash) + explored = exploredInt > 0 + return +} + +func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error { + _, err := p.updateChildMetadataExploredStmt.ExecContext(ctx, 1, eventID) + return err +} + func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { var rows *sql.Rows var err error @@ -268,7 +352,7 @@ func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash Hash gomatrixserverlib.Base64Bytes `json:"children_hash"` }{} if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil { - util.GetLogger(context.Background()).WithError(err).Error("failed to read unsigned field of event") + // expected if there is no unsigned field at all return } for _, c := range unsigned.Counts {