diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index b213e057b..81b9b06e8 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -74,12 +75,13 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventStateKeysSchema) +func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, @@ -87,7 +89,7 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -96,7 +98,7 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -105,7 +107,7 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( @@ -128,7 +130,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, rows.Err() } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( +func (s *eventStateKeyStatements) BulkSelectEventStateKey( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 2b0910e71..aaba614a4 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -98,36 +99,37 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventTypesSchema) +func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, }.prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) InsertEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.insertEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( - ctx context.Context, eventType string, +func (s *eventTypeStatements) SelectEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + err := txn.Stmt(s.selectEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( +func (s *eventTypeStatements) BulkSelectEventTypeNID( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index 5956886ce..e41c5a398 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -38,8 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, s.roomStatements.prepare, s.eventStatements.prepare, s.eventJSONStatements.prepare, diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 4d1d603e3..6fcceced5 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -26,14 +26,19 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database + statements statements + eventTypes tables.EventTypes + eventStateKeys tables.EventStateKeys + db *sql.DB } // Open a postgres database. @@ -46,6 +51,18 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err = d.statements.prepare(d.db); err != nil { return nil, err } + d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db) + if err != nil { + return nil, err + } + d.eventTypes, err = NewPostgresEventTypesTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + EventTypesTable: d.eventTypes, + EventStateKeysTable: d.eventStateKeys, + } return &d, nil } @@ -180,17 +197,20 @@ func (d *Database) assignRoomNID( func (d *Database) assignEventTypeNID( ctx context.Context, eventType string, -) (types.EventTypeNID, error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) +) (eventTypeNID types.EventTypeNID, err error) { + err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) + } } - } + return err + }) return eventTypeNID, err } @@ -198,13 +218,13 @@ func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err := d.eventStateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.eventStateKeys.InsertEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.eventStateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } return eventStateKeyNID, err @@ -217,27 +237,6 @@ func (d *Database) StateEntriesForEventIDs( return d.statements.bulkSelectStateEventByID(ctx, eventIDs) } -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) -} - // EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go new file mode 100644 index 000000000..7a8da8658 --- /dev/null +++ b/roomserver/storage/shared/storage.go @@ -0,0 +1,34 @@ +package shared + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +type Database struct { + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) +} diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index f49ebf554..0d3d323fb 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -67,13 +68,14 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} s.db = db - _, err = db.Exec(eventStateKeysSchema) + _, err := db.Exec(eventStateKeysSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, @@ -81,7 +83,7 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID( +func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -94,7 +96,7 @@ func (s *eventStateKeyStatements) insertEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID( +func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 @@ -103,8 +105,8 @@ func (s *eventStateKeyStatements) selectEventStateKeyNID( return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKeys []string, +func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { @@ -112,7 +114,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", internal.QueryVariadic(len(eventStateKeys)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { return nil, err } @@ -129,8 +131,8 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( return result, nil } -func (s *eventStateKeyStatements) bulkSelectEventStateKey( - ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, +func (s *eventStateKeyStatements) BulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { @@ -138,7 +140,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", internal.QueryVariadic(len(eventStateKeyNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 13abcd4df..d47be5453 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,14 +82,15 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} s.db = db - _, err = db.Exec(eventTypesSchema) + _, err := db.Exec(eventTypesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, @@ -96,7 +98,7 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID( +func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 @@ -109,7 +111,7 @@ func (s *eventTypeStatements) insertEventTypeNID( return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID( +func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 @@ -118,8 +120,8 @@ func (s *eventTypeStatements) selectEventTypeNID( return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID( - ctx context.Context, tx *sql.Tx, eventTypes []string, +func (s *eventTypeStatements) BulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -133,8 +135,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID( } /////////////// - selectStmt := internal.TxStmt(tx, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventTypes...) + rows, err := selectPrep.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index 0d49432b8..bb3318b2d 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -38,8 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.eventTypeStatements.prepare, - s.eventStateKeyStatements.prepare, s.roomStatements.prepare, s.eventStatements.prepare, s.eventJSONStatements.prepare, diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index bb38f800f..b9157e3a5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -26,6 +26,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -33,8 +35,11 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { - statements statements - db *sql.DB + shared.Database + statements statements + eventTypes tables.EventTypes + eventStateKeys tables.EventStateKeys + db *sql.DB } // Open a sqlite database. @@ -66,6 +71,18 @@ func Open(dataSourceName string) (*Database, error) { if err = d.statements.prepare(d.db); err != nil { return nil, err } + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) + if err != nil { + return nil, err + } + d.eventTypes, err = NewSqliteEventTypesTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + EventTypesTable: d.eventTypes, + EventStateKeysTable: d.eventStateKeys, + } return &d, nil } @@ -210,13 +227,13 @@ func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (eventTypeNID types.EventTypeNID, err error) { // Check if we already have a numeric ID in the database. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) + eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) } } return @@ -226,13 +243,13 @@ func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (eventStateKeyNID types.EventStateKeyNID, err error) { // Check if we already have a numeric ID in the database. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.eventStateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.eventStateKeys.InsertEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + eventStateKeyNID, err = d.eventStateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } return @@ -249,39 +266,6 @@ func (d *Database) StateEntriesForEventIDs( return } -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (etnids map[string]types.EventTypeNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) - return err - }) - return -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (esknids map[string]types.EventStateKeyNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) - return err - }) - return -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (out map[types.EventStateKeyNID]string, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) - return err - }) - return -} - // EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go new file mode 100644 index 000000000..d607865dc --- /dev/null +++ b/roomserver/storage/tables/interface.go @@ -0,0 +1,21 @@ +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +type EventTypes interface { + InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) +} + +type EventStateKeys interface { + InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) + BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) +}