diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index c3d661afb..307911ef1 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -16,6 +16,50 @@ // Hooks can only be run in monolith mode. package hooks -func Attach() { +import "sync" +const ( + // KindNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent + // It is run when a new event is persisted in the roomserver. + // Usage: + // hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { ... }) + KindNewEvent = "new_event" +) + +var ( + hookMap = make(map[string][]func(interface{})) + hookMu = sync.Mutex{} + enabled = false +) + +// Enable all hooks. This may slow down the server slightly. Required for MSCs to work. +func Enable() { + enabled = true +} + +// Run any hooks +func Run(kind string, data interface{}) { + if !enabled { + return + } + cbs := callbacks(kind) + for _, cb := range cbs { + cb(data) + } +} + +// Attach a hook +func Attach(kind string, callback func(interface{})) { + if !enabled { + return + } + hookMu.Lock() + defer hookMu.Unlock() + hookMap[kind] = append(hookMap[kind], callback) +} + +func callbacks(kind string) []func(interface{}) { + hookMu.Lock() + defer hookMu.Unlock() + return hookMap[kind] } diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 491b89e14..1277a44df 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -16,14 +16,17 @@ package msc2836 import ( + "context" "encoding/json" "fmt" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/setup" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -42,10 +45,20 @@ type eventRelationshipRequest struct { // Enable this MSC func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { - _, err := NewDatabase(&base.Cfg.MSCs.Database) + db, err := NewDatabase(&base.Cfg.MSCs.Database) if err != nil { return fmt.Errorf("Cannot enable MSC2836: %w", err) } + hooks.Enable() + hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { + he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + hookErr := db.StoreRelation(context.Background(), he) + if hookErr != nil { + util.GetLogger(context.Background()).WithError(hookErr).Error( + "failed to StoreRelation", + ) + } + }) base.PublicClientAPIMux.Handle("/unstable/event_relationships", httputil.MakeAuthAPI("eventRelationships", monolith.UserAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go index 2a8cb99a6..7401d7fbc 100644 --- a/internal/mscs/msc2836/storage.go +++ b/internal/mscs/msc2836/storage.go @@ -3,6 +3,7 @@ package msc2836 import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -15,7 +16,8 @@ type Database interface { } type Postgres struct { - db *sql.DB + db *sql.DB + insertRelationStmt *sql.Stmt } func NewPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { @@ -24,16 +26,35 @@ func NewPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { if p.db, err = sqlutil.Open(dbOpts); err != nil { return nil, err } - return &p, nil + _, err = p.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_relationships ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + CONSTRAINT msc2836_relationships_unique UNIQUE (parent_event_id, child_event_id) + ); + `) + if p.insertRelationStmt, err = p.db.Prepare(` + INSERT INTO msc2836_relationships(parent_event_id, child_event_id, parent_room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + return &p, err } -func (db *Postgres) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { - return nil +func (p *Postgres) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + parent, child := parentChildEventIDs(ev) + if parent == "" || child == "" { + return nil + } + _, err := p.insertRelationStmt.ExecContext(ctx, parent, child, "") + return err } type SQLite struct { - db *sql.DB - writer sqlutil.Writer + db *sql.DB + insertRelationStmt *sql.Stmt + writer sqlutil.Writer } func NewSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { @@ -43,11 +64,29 @@ func NewSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { return nil, err } s.writer = sqlutil.NewExclusiveWriter() + _, err = s.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_relationships ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (parent_event_id, child_event_id) + ); + `) + if s.insertRelationStmt, err = s.db.Prepare(` + INSERT INTO msc2836_relationships(parent_event_id, child_event_id, room_id) VALUES($1, $2, $3) ON CONFLICT (parent_event_id, child_event_id) DO NOTHING + `); err != nil { + return nil, err + } return &s, nil } -func (db *SQLite) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { - return nil +func (s *SQLite) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + parent, child := parentChildEventIDs(ev) + if parent == "" || child == "" { + return nil + } + _, err := s.insertRelationStmt.ExecContext(ctx, parent, child, "") + return err } // NewDatabase loads the database for msc2836 @@ -57,3 +96,22 @@ func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { } return NewSQLiteDatabase(dbOpts) } + +func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent string, child string) { + if ev == nil { + return + } + body := struct { + Relationship struct { + RelType string `json:"rel_type"` + EventID string `json:"event_id"` + } `json:"m.relationship"` + }{} + if err := json.Unmarshal(ev.Content(), &body); err != nil { + return + } + if body.Relationship.RelType == "m.reference" && body.Relationship.EventID != "" { + return body.Relationship.EventID, ev.EventID() + } + return +} diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 99c15f77a..d25030a3c 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -22,6 +22,7 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" @@ -62,6 +63,9 @@ func (w *inputWorker) start() { select { case task := <-w.input: _, task.err = w.r.processRoomEvent(task.ctx, task.event) + if task.err == nil { + hooks.Run(hooks.KindNewEvent, &task.event.Event) + } task.wg.Done() case <-time.After(time.Second * 5): return