diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index e3ad5dc80..b3d32c953 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,19 +64,20 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, @@ -91,7 +93,7 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index c77edd0e0..f869cf4fb 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -59,12 +60,13 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, @@ -73,14 +75,14 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( +func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( +func (s *roomAliasesStatements) SelectRoomIDFromAlias( ctx context.Context, alias string, ) (roomID string, err error) { err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) @@ -90,7 +92,7 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias( return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( +func (s *roomAliasesStatements) SelectAliasesFromRoomID( ctx context.Context, roomID string, ) ([]string, error) { rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) @@ -111,7 +113,7 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return aliases, rows.Err() } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( ctx context.Context, alias string, ) (creatorID string, err error) { err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) @@ -121,7 +123,7 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias( return } -func (s *roomAliasesStatements) deleteRoomAlias( +func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, ) (err error) { _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index 914f269c5..eb626dd88 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -38,10 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, } { diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 38334fa9e..d1aaaa003 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -87,13 +88,14 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateDataSchema) +func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, @@ -101,11 +103,15 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, - stateBlockNID types.StateBlockNID, + txn *sql.Tx, entries []types.StateEntry, -) error { +) (types.StateBlockNID, error) { + stateBlockNID, err := s.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } for _, entry := range entries { _, err := s.insertStateDataStmt.ExecContext( ctx, @@ -115,10 +121,10 @@ func (s *stateBlockStatements) bulkInsertStateData( int64(entry.EventNID), ) if err != nil { - return err + return 0, err } } - return nil + return stateBlockNID, nil } func (s *stateBlockStatements) selectNextStateBlockNID( @@ -129,7 +135,7 @@ func (s *stateBlockStatements) selectNextStateBlockNID( return types.StateBlockNID(stateBlockNID), err } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( +func (s *stateBlockStatements) BulkSelectStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) @@ -180,7 +186,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a1f26e228..8971292fc 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -64,30 +65,31 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(stateSnapshotSchema) +func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.prepare(db) } -func (s *stateSnapshotStatements) insertState( - ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +func (s *stateSnapshotStatements) InsertState( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index d44da8585..03cfb7f0e 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -40,10 +40,12 @@ type Database struct { eventJSON tables.EventJSON rooms tables.Rooms transactions tables.Transactions + prevEvents tables.PreviousEvents db *sql.DB } // Open a postgres database. +// nolint: gocyclo func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) { var d Database var err error @@ -77,6 +79,22 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + stateBlock, err := NewPostgresStateBlockTable(d.db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewPostgresStateSnapshotTable(d.db) + if err != nil { + return nil, err + } + roomAliases, err := NewPostgresRoomAliasesTable(d.db) + if err != nil { + return nil, err + } + d.prevEvents, err = NewPostgresPreviousEventsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -85,6 +103,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, EventsTable: d.events, RoomsTable: d.rooms, TransactionsTable: d.transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: d.prevEvents, + RoomAliasesTable: roomAliases, } return &d, nil } @@ -122,41 +144,6 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (types.StateSnapshotNID, error) { - if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { - return 0, err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - - return d.statements.insertState(ctx, roomNID, stateBlockNIDs) -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -222,7 +209,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + if err := u.d.prevEvents.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { return err } } @@ -231,7 +218,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p // IsReferenced implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + err := u.d.prevEvents.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) if err == nil { return true, nil } @@ -276,44 +263,6 @@ func (d *Database) GetInvitesForUser( return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) } -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, - ) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 814b6e812..29c6f73eb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -20,6 +20,10 @@ type Database struct { EventStateKeysTable tables.EventStateKeys RoomsTable tables.Rooms TransactionsTable tables.Transactions + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents } // EventTypeNIDs implements state.RoomStateDatabase @@ -50,6 +54,42 @@ func (d *Database) StateEntriesForEventIDs( return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) } +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + if len(state) > 0 { + var stateBlockNID types.StateBlockNID + stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) + if err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) + return err + }) + if err != nil { + return 0, err + } + return +} + // EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, @@ -150,6 +190,20 @@ func (d *Database) LatestEventIDs( return } +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + func (d *Database) GetRoomVersionForRoom( ctx context.Context, roomID string, ) (gomatrixserverlib.RoomVersion, error) { @@ -166,6 +220,33 @@ func (d *Database) GetRoomVersionForRoomNID( ) } +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) +} + +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) +} + // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index f344bda95..6b758ccc6 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -55,19 +56,20 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func (s *previousEventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(previousEventSchema) +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} + _, err := db.Exec(previousEventSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.prepare(db) } -func (s *previousEventStatements) insertPreviousEvent( +func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, @@ -83,7 +85,7 @@ func (s *previousEventStatements) insertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists( +func (s *previousEventStatements) SelectPreviousEventExists( ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, ) error { var ok int64 diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 592ef9782..686391f68 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const roomAliasesSchema = ` @@ -60,12 +61,13 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomAliasesSchema) +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + _, err := db.Exec(roomAliasesSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, @@ -74,31 +76,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, +func (s *roomAliasesStatements) InsertRoomAlias( + ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { - insertStmt := internal.TxStmt(txn, s.insertRoomAliasStmt) - _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID) + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectRoomIDFromAlias( + ctx context.Context, alias string, ) (roomID string, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID) + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) selectAliasesFromRoomID( - ctx context.Context, txn *sql.Tx, roomID string, +func (s *roomAliasesStatements) SelectAliasesFromRoomID( + ctx context.Context, roomID string, ) (aliases []string, err error) { aliases = []string{} - selectStmt := internal.TxStmt(txn, s.selectAliasesFromRoomIDStmt) - rows, err := selectStmt.QueryContext(ctx, roomID) + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { return } @@ -117,21 +116,19 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } -func (s *roomAliasesStatements) selectCreatorIDFromAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) SelectCreatorIDFromAlias( + ctx context.Context, alias string, ) (creatorID string, err error) { - selectStmt := internal.TxStmt(txn, s.selectCreatorIDFromAliasStmt) - err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) deleteRoomAlias( - ctx context.Context, txn *sql.Tx, alias string, +func (s *roomAliasesStatements) DeleteRoomAlias( + ctx context.Context, alias string, ) (err error) { - deleteStmt := internal.TxStmt(txn, s.deleteRoomAliasStmt) - _, err = deleteStmt.ExecContext(ctx, alias) + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) return } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index fe899174a..df994d508 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -38,10 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.stateSnapshotStatements.prepare, - s.stateBlockStatements.prepare, - s.previousEventStatements.prepare, - s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, } { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 861d76cf9..e3ca0b3c8 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -77,14 +78,15 @@ type stateBlockStatements struct { bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} s.db = db - _, err = db.Exec(stateDataSchema) + _, err := db.Exec(stateDataSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, @@ -92,7 +94,7 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData( +func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, entries []types.StateEntry, ) (types.StateBlockNID, error) { @@ -120,19 +122,18 @@ func (s *stateBlockStatements) bulkInsertStateData( return stateBlockNID, nil } -func (s *stateBlockStatements) bulkSelectStateBlockEntries( - ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, +func (s *stateBlockStatements) BulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { nids := make([]interface{}, len(stateBlockNIDs)) for k, v := range stateBlockNIDs { nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) - selectPrep, err := s.db.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { return nil, err @@ -174,8 +175,8 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( return results, nil } -func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( - ctx context.Context, txn *sql.Tx, // nolint: unparam +func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 8682627a0..42eae1a7f 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,20 +52,21 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} s.db = db - _, err = db.Exec(stateSnapshotSchema) + _, err := db.Exec(stateSnapshotSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, }.prepare(db) } -func (s *stateSnapshotStatements) insertState( +func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, ) (stateNID types.StateSnapshotNID, err error) { stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) @@ -82,15 +84,15 @@ func (s *stateSnapshotStatements) insertState( return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( - ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { nids[k] = v } selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", internal.QueryVariadic(len(nids)), 1) - selectStmt, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index a0a1b5684..16697f1b9 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -41,10 +41,12 @@ type Database struct { eventStateKeys tables.EventStateKeys rooms tables.Rooms transactions tables.Transactions + prevEvents tables.PreviousEvents db *sql.DB } // Open a sqlite database. +// nolint: gocyclo func Open(dataSourceName string) (*Database, error) { var d Database uri, err := url.Parse(dataSourceName) @@ -97,6 +99,22 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + stateBlock, err := NewSqliteStateBlockTable(d.db) + if err != nil { + return nil, err + } + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + if err != nil { + return nil, err + } + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + if err != nil { + return nil, err + } + roomAliases, err := NewSqliteRoomAliasesTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventsTable: d.events, @@ -105,6 +123,10 @@ func Open(dataSourceName string) (*Database, error) { EventJSONTable: d.eventJSON, RoomsTable: d.rooms, TransactionsTable: d.transactions, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: d.prevEvents, + RoomAliasesTable: roomAliases, } return &d, nil } @@ -142,53 +164,6 @@ func (d *Database) assignStateKeyNID( return } -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (stateNID types.StateSnapshotNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if len(state) > 0 { - var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state) - if err != nil { - return err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) - return err - }) - if err != nil { - return 0, err - } - return -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) (sl []types.StateBlockNIDList, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) - return err - }) - return -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) (sel []types.StateEntryList, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) - return err - }) - return -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -264,7 +239,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + if err := u.d.prevEvents.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { return err } } @@ -276,7 +251,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p // IsReferenced implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) + err := u.d.prevEvents.SelectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) if err == nil { res = true err = nil @@ -339,44 +314,6 @@ func (d *Database) GetInvitesForUser( return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) } -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, nil, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, nil, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, nil, stateBlockNIDs, stateKeyTuples, - ) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index e913de6b4..56ef5dd39 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -70,3 +70,29 @@ type Transactions interface { InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) } + +type StateSnapshot interface { + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) +} + +type StateBlock interface { + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) +} + +type RoomAliases interface { + InsertRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) (err error) + SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + DeleteRoomAlias(ctx context.Context, alias string) (err error) +} + +type PreviousEvents interface { + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + // Check if the event reference exists + // Returns sql.ErrNoRows if the event reference doesn't exist. + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error +}