dendrite/internal/mscs/msc2836/storage.go
2020-11-03 15:31:10 +00:00

174 lines
4.9 KiB
Go

package msc2836
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is the storage interface used by the msc2836 package to persist and query
// relationships between events.
type Database interface {
// StoreRelation stores the parent->child and child->parent relationship for later querying.
// It also stores metadata about the event, e.g. its timestamp.
StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
// ChildrenForParent returns the IDs of the events which reference the given `eventID` in their
// m.relationship with the provided `relType`.
ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error)
}
// DB is the SQL-backed implementation of Database. The same struct serves both the Postgres
// and the SQLite backends; the constructors differ in the writer they install and in the SQL
// they prepare.
type DB struct {
// db is the handle returned by sqlutil.Open.
db *sql.DB
// writer runs the write transaction in StoreRelation (a dummy writer for Postgres, an
// exclusive writer for SQLite).
writer sqlutil.Writer
// insertEdgeStmt inserts a (parent, child, rel_type) row into msc2836_edges.
insertEdgeStmt *sql.Stmt
// insertNodeStmt inserts an (event_id, origin_server_ts, room_id) row into msc2836_nodes.
insertNodeStmt *sql.Stmt
// selectChildrenForParentStmt selects child event IDs by parent event ID and rel_type.
selectChildrenForParentStmt *sql.Stmt
}
// NewDatabase opens the msc2836 database described by dbOpts. A Postgres connection
// string selects the Postgres backend; anything else falls back to SQLite.
func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
	switch {
	case dbOpts.ConnectionString.IsPostgres():
		return newPostgresDatabase(dbOpts)
	default:
		return newSQLiteDatabase(dbOpts)
	}
}
// newPostgresDatabase opens the Postgres database described by dbOpts, creates the
// msc2836 tables if they do not already exist and prepares the statements used by DB.
// Errors from opening the database, creating the schema or preparing a statement are
// returned unchanged, together with a nil Database.
func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
	d := DB{
		writer: sqlutil.NewDummyWriter(),
	}
	var err error
	if d.db, err = sqlutil.Open(dbOpts); err != nil {
		return nil, err
	}
	// The unique constraint must not share the table's name: Postgres backs a unique
	// constraint with an index of the same name, and indexes live in the same namespace
	// as tables, so a constraint called msc2836_edges makes the CREATE TABLE fail.
	_, err = d.db.Exec(`
CREATE TABLE IF NOT EXISTS msc2836_edges (
parent_event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
rel_type TEXT NOT NULL,
CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type)
);
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL
);
`)
	if err != nil {
		return nil, err
	}
	if d.insertEdgeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
`); err != nil {
		return nil, err
	}
	if d.insertNodeStmt, err = d.db.Prepare(`
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
`); err != nil {
		return nil, err
	}
	if d.selectChildrenForParentStmt, err = d.db.Prepare(`
SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2
`); err != nil {
		return nil, err
	}
	// Every failure path has already returned, so report success explicitly.
	return &d, nil
}
// newSQLiteDatabase opens the SQLite database described by dbOpts, creates the msc2836
// tables if they are missing and prepares the statements used by DB. The first error
// encountered is returned unchanged, together with a nil Database.
func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
	const schema = `
CREATE TABLE IF NOT EXISTS msc2836_edges (
parent_event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
rel_type TEXT NOT NULL,
UNIQUE (parent_event_id, child_event_id, rel_type)
);
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL
);
`
	const insertEdgeSQL = `
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type) VALUES($1, $2, $3) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
`
	const insertNodeSQL = `
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
`
	const selectChildrenSQL = `
SELECT child_event_id FROM msc2836_edges WHERE parent_event_id = $1 AND rel_type = $2
`

	store := &DB{
		writer: sqlutil.NewExclusiveWriter(),
	}
	conn, err := sqlutil.Open(dbOpts)
	store.db = conn
	if err != nil {
		return nil, err
	}
	if _, err = conn.Exec(schema); err != nil {
		return nil, err
	}

	// Prepare the statements in a fixed order, stopping at the first failure.
	toPrepare := []struct {
		target **sql.Stmt
		query  string
	}{
		{&store.insertEdgeStmt, insertEdgeSQL},
		{&store.insertNodeStmt, insertNodeSQL},
		{&store.selectChildrenForParentStmt, selectChildrenSQL},
	}
	for _, item := range toPrepare {
		if *item.target, err = conn.Prepare(item.query); err != nil {
			return nil, err
		}
	}
	return store, nil
}
// StoreRelation records the relationship described by ev, if it has one. Events for which
// parentChildEventIDs yields an empty parent or child are ignored and nil is returned.
// Otherwise the edge row and the node row for ev are inserted via the DB's writer; the
// first insert error is returned.
func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
	parentID, childID, relType := parentChildEventIDs(ev)
	if parentID == "" || childID == "" {
		return nil
	}
	insertRows := func(txn *sql.Tx) error {
		if _, edgeErr := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parentID, childID, relType); edgeErr != nil {
			return edgeErr
		}
		_, nodeErr := txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID())
		return nodeErr
	}
	return p.writer.Do(p.db, nil, insertRows)
}
// ChildrenForParent returns the child event IDs stored in msc2836_edges for the given
// parent eventID and relType. The returned slice is nil when there are no matching rows.
// An error is returned if the query fails, if a row cannot be scanned, or if iterating
// over the result set is cut short by an error.
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string) ([]string, error) {
	rows, err := p.selectChildrenForParentStmt.QueryContext(ctx, eventID, relType)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck
	var children []string
	for rows.Next() {
		var childID string
		if err := rows.Scan(&childID); err != nil {
			return nil, err
		}
		children = append(children, childID)
	}
	// rows.Next also returns false when iteration fails part-way (e.g. a cancelled context
	// or a dropped connection); without this check a truncated list would look like success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return children, nil
}
// parentChildEventIDs extracts the relationship carried in the `m.relationship` field of
// the event's content. It returns the referenced event ID as parent, the event's own ID
// as child, and the `rel_type`. All three results are empty when ev is nil, when the
// content cannot be decoded, or when either the referenced event ID or the rel_type is
// missing.
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
	if ev == nil {
		return "", "", ""
	}
	var content struct {
		Relationship struct {
			RelType string `json:"rel_type"`
			EventID string `json:"event_id"`
		} `json:"m.relationship"`
	}
	if json.Unmarshal(ev.Content(), &content) != nil {
		return "", "", ""
	}
	rel := content.Relationship
	if rel.EventID != "" && rel.RelType != "" {
		return rel.EventID, ev.EventID(), rel.RelType
	}
	return "", "", ""
}