diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 661c44721..a32629260 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,32 +59,28 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventJSONSchema) +func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( +func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +) ([]tables.EventJSONPair, error) { rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -94,7 +91,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index e41c5a398..7afd1f830 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -40,7 +40,6 @@ func (s *statements) prepare(db *sql.DB) error { for _, prepare := range []func(db *sql.DB) error{ s.roomStatements.prepare, s.eventStatements.prepare, - s.eventJSONStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6fcceced5..e952d7517 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -38,6 +38,7 @@ type Database struct { statements statements eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys + eventJSON tables.EventJSON db *sql.DB } @@ -59,9 +60,14 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.eventJSON, err = NewPostgresEventJSONTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSON: d.eventJSON, } return &d, nil } @@ -139,7 +145,7 @@ func (d *Database) StoreEvent( } } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil { return 0, types.StateAtEvent{}, err } @@ -248,7 +254,7 @@ func (d *Database) EventNIDs( func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 7a8da8658..85cb0754a 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -8,6 +8,7 @@ import ( ) type Database struct { + EventJSON tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index fbf35e711..34b067cb0 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,40 +52,36 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} s.db = db - _, err = db.Exec(eventJSONSchema) + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON( +func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := internal.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +func (s *eventJSONStatements) BulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { return nil, err } @@ -94,7 +91,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index bb3318b2d..00ab3df6f 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -40,7 +40,6 @@ func (s *statements) prepare(db *sql.DB) error { for _, prepare := range []func(db *sql.DB) error{ s.roomStatements.prepare, s.eventStatements.prepare, - s.eventJSONStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index b9157e3a5..d3f230ca9 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -37,6 +37,7 @@ import ( type Database struct { shared.Database statements statements + eventJSON tables.EventJSON eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys db *sql.DB @@ -79,9 +80,14 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.eventJSON, err = NewSqliteEventJSONTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSON: d.eventJSON, } return &d, nil } @@ -161,7 +167,7 @@ func (d *Database) StoreEvent( } } - if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return err } @@ -281,14 +287,14 @@ func (d *Database) EventNIDs( func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - var eventJSONs []eventJSONPair + var eventJSONs []tables.EventJSONPair var err error var results []types.Event + eventJSONs, err = d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil || len(eventJSONs) == 0 { + return nil, nil + } err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) - if err != nil || len(eventJSONs) == 0 { - return nil - } results = make([]types.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { var roomNID types.RoomNID diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index d607865dc..4553bacb1 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -7,6 +7,16 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" ) +type EventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +type EventJSON interface { + InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error + BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) +} + type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)