diff --git a/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go index c4dcd639b..73997484b 100644 --- a/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go +++ b/roomserver/storage/postgres/deltas/20210416150927_state_blocks_refactor.go @@ -41,7 +41,8 @@ func LoadStateBlocksRefactor(m *sqlutil.Migrations) { } func UpStateBlocksRefactor(tx *sql.Tx) error { - logrus.Warn("Performing state block refactor upgrade. Please wait, this may take some time!") + logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") + defer logrus.Warn("State storage upgrade complete") if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { return fmt.Errorf("tx.Exec: %w", err) diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 8f11d1d8e..e0976b12c 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -59,12 +59,14 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { - s := &eventJSONStatements{} +func createEventJSONTable(db *sql.DB) error { _, err := db.Exec(eventJSONSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 500ff20e4..616823561 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -77,12 +77,14 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { - s := &eventStateKeyStatements{} +func createEventStateKeysTable(db *sql.DB) error { _, err := db.Exec(eventStateKeysSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 02d6ad079..f4257850a 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -100,12 +100,13 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { - s := &eventTypeStatements{} +func createEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 052b0b13b..605051edb 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -149,12 +149,13 @@ type eventStatements struct { selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { - s := &eventStatements{} +func createEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index bb7195164..0a2183e27 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -82,12 +82,13 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { - s := &inviteStatements{} +func createInvitesTable(db *sql.DB) error { _, err := db.Exec(inviteSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index e392a4fbb..3466da6d2 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -139,12 +139,13 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt } -func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { - s := &membershipStatements{} +func createMembershipTable(db *sql.DB) error { _, err := db.Exec(membershipSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, @@ -162,11 +163,6 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { }.Prepare(db) } -func (s *membershipStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(membershipSchema) - return err -} - func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 1a4ba6732..4a93c3d65 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -65,12 +65,13 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { - s := &previousEventStatements{} +func createPrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - if err != nil { - return nil, err - } + return err +} + +func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 440ae7842..c180576e3 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -50,12 +50,14 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func NewPostgresPublishedTable(db *sql.DB) (tables.Published, error) { - s := &publishedStatements{} +func createPublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) - if err != nil { - return nil, err - } + return err +} + +func preparePublishedTable(db *sql.DB) (tables.Published, error) { + s := &publishedStatements{} + return s, shared.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, diff --git a/roomserver/storage/postgres/redactions_table.go b/roomserver/storage/postgres/redactions_table.go index 42aba5985..3741d5f67 100644 --- a/roomserver/storage/postgres/redactions_table.go +++ b/roomserver/storage/postgres/redactions_table.go @@ -60,12 +60,13 @@ type redactionStatements struct { markRedactionValidatedStmt *sql.Stmt } -func NewPostgresRedactionsTable(db *sql.DB) (tables.Redactions, error) { - s := &redactionStatements{} +func createRedactionsTable(db *sql.DB) error { _, err := db.Exec(redactionsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { + s := &redactionStatements{} return s, shared.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index b603a673c..c808813ee 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -62,12 +62,14 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { - s := &roomAliasesStatements{} +func createRoomAliasesTable(db *sql.DB) error { _, err := db.Exec(roomAliasesSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 637680bde..f2b39fe54 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -96,12 +96,14 @@ type roomStatements struct { bulkSelectRoomNIDsStmt *sql.Stmt } -func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { - s := &roomStatements{} +func createRoomsTable(db *sql.DB) error { _, err := db.Exec(roomsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index e75663fb7..6a986aaa4 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -61,12 +61,13 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{} +func createStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index f5b0c0cd4..9841019c5 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -67,12 +67,13 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { - s := &stateSnapshotStatements{} +func createStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index a7eb4e100..863a15939 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -43,80 +43,130 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) return nil, fmt.Errorf("sqlutil.Open: %w", err) } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - ms := membershipStatements{} - if err := ms.execSchema(db); err != nil { - return nil, fmt.Errorf("ms.execSchema: %w", err) + // Create the tables. + if err := d.create(db); err != nil { + return nil, err } + + // Then execute the migrations. By this point the tables are created with the latest + // schemas. m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { - return nil, fmt.Errorf("m.RunDeltas: %w", err) + return nil, err } + + // Then prepare the statements. Now that the migrations have run, any columns referred + // to in the database code should now exist. if err := d.prepare(db, cache); err != nil { - return nil, fmt.Errorf("d.prepare: %w", err) + return nil, err } return &d, nil } -// nolint: gocyclo -func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) { - eventStateKeys, err := NewPostgresEventStateKeysTable(db) +func (d *Database) create(db *sql.DB) error { + if err := createEventStateKeysTable(db); err != nil { + return err + } + if err := createEventTypesTable(db); err != nil { + return err + } + if err := createEventJSONTable(db); err != nil { + return err + } + if err := createEventsTable(db); err != nil { + return err + } + if err := createRoomsTable(db); err != nil { + return err + } + if err := createTransactionsTable(db); err != nil { + return err + } + if err := createStateBlockTable(db); err != nil { + return err + } + if err := createStateSnapshotTable(db); err != nil { + return err + } + if err := createPrevEventsTable(db); err != nil { + return err + } + if err := createRoomAliasesTable(db); err != nil { + return err + } + if err := createInvitesTable(db); err != nil { + return err + } + if err := createMembershipTable(db); err != nil { + return err + } + if err := createPublishedTable(db); err != nil { + return err + } + if err := createRedactionsTable(db); err != nil { + return err + } + + return nil +} + +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { + eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := NewPostgresEventTypesTable(db) + eventTypes, err := prepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := NewPostgresEventJSONTable(db) + eventJSON, err := prepareEventJSONTable(db) if err != nil { return err } - events, err := NewPostgresEventsTable(db) + events, err := prepareEventsTable(db) if err != nil { return err } - rooms, err := NewPostgresRoomsTable(db) + rooms, err := prepareRoomsTable(db) if err != nil { return err } - transactions, err := NewPostgresTransactionsTable(db) + transactions, err := prepareTransactionsTable(db) if err != nil { return err } - stateBlock, err := NewPostgresStateBlockTable(db) + stateBlock, err := prepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := NewPostgresStateSnapshotTable(db) + stateSnapshot, err := prepareStateSnapshotTable(db) if err != nil { return err } - roomAliases, err := NewPostgresRoomAliasesTable(db) + prevEvents, err := preparePrevEventsTable(db) if err != nil { return err } - prevEvents, err := NewPostgresPreviousEventsTable(db) + roomAliases, err := prepareRoomAliasesTable(db) if err != nil { return err } - invites, err := NewPostgresInvitesTable(db) + invites, err := prepareInvitesTable(db) if err != nil { return err } - membership, err := NewPostgresMembershipTable(db) + membership, err := prepareMembershipTable(db) if err != nil { return err } - published, err := NewPostgresPublishedTable(db) + published, err := preparePublishedTable(db) if err != nil { return err } - redactions, err := NewPostgresRedactionsTable(db) + redactions, err := prepareRedactionsTable(db) if err != nil { return err } diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 5e59ae16d..cada0d8aa 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -54,12 +54,13 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { - s := &transactionStatements{} +func createTransactionsTable(db *sql.DB) error { _, err := db.Exec(transactionsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, diff --git a/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go index 53cf9ae69..246c2054c 100644 --- a/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go +++ b/roomserver/storage/sqlite3/deltas/20210416150927_state_blocks_refactor.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) func LoadStateBlocksRefactor(m *sqlutil.Migrations) { @@ -31,6 +32,8 @@ func LoadStateBlocksRefactor(m *sqlutil.Migrations) { } func UpStateBlocksRefactor(tx *sql.Tx) error { + logrus.Warn("Performing state block refactor upgrade. Please wait, this may take some time!") + if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { return fmt.Errorf("tx.Exec: %w", err) } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 3cd44b1dc..29d54b83d 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -53,14 +53,16 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { +func createEventJSONTable(db *sql.DB) error { + _, err := db.Exec(eventJSONSchema) + return err +} + +func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ db: db, } - _, err := db.Exec(eventJSONSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 345df8c62..d430e5535 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -70,14 +70,16 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { +func createEventStateKeysTable(db *sql.DB) error { + _, err := db.Exec(eventStateKeysSchema) + return err +} + +func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ db: db, } - _, err := db.Exec(eventStateKeysSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 26e2bf843..694f4e217 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -85,14 +85,15 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func createEventTypesTable(db *sql.DB) error { + _, err := db.Exec(eventTypesSchema) + return err +} + +func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, } - _, err := db.Exec(eventTypesSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 420a4845e..5cbce9f5f 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -121,14 +121,15 @@ type eventStatements struct { //selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { +func createEventsTable(db *sql.DB) error { + _, err := db.Exec(eventsSchema) + return err +} + +func prepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ db: db, } - _, err := db.Exec(eventsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 327be6a03..e1aa1ebd3 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -70,14 +70,15 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { +func createInvitesTable(db *sql.DB) error { + _, err := db.Exec(inviteSchema) + return err +} + +func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{ db: db, } - _, err := db.Exec(inviteSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index d716ced04..d9fe32cf8 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -115,7 +115,12 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { +func createMembershipTable(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + +func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ db: db, } @@ -135,11 +140,6 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { }.Prepare(db) } -func (s *membershipStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(membershipSchema) - return err -} - func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index aaee62733..3cb527678 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -71,14 +71,15 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func createPrevEventsTable(db *sql.DB) error { + _, err := db.Exec(previousEventSchema) + return err +} + +func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ db: db, } - _, err := db.Exec(previousEventSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index dcf6f697a..6d9d91355 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -50,14 +50,16 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { +func createPublishedTable(db *sql.DB) error { + _, err := db.Exec(publishedSchema) + return err +} + +func preparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ db: db, } - _, err := db.Exec(publishedSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index e64714862..b34981829 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -59,14 +59,15 @@ type redactionStatements struct { markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { +func createRedactionsTable(db *sql.DB) error { + _, err := db.Exec(redactionsSchema) + return err +} + +func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{ db: db, } - _, err := db.Exec(redactionsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index f053e3981..5215fa6f7 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -64,14 +64,16 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { +func createRoomAliasesTable(db *sql.DB) error { + _, err := db.Exec(roomAliasesSchema) + return err +} + +func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ db: db, } - _, err := db.Exec(roomAliasesSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index fe8e601f5..534a870cc 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -86,14 +86,16 @@ type roomStatements struct { selectRoomIDsStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { +func createRoomsTable(db *sql.DB) error { + _, err := db.Exec(roomsSchema) + return err +} + +func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ db: db, } - _, err := db.Exec(roomsSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 508779dd5..22240aa6b 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -56,14 +56,15 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func createStateBlockTable(db *sql.DB) error { + _, err := db.Exec(stateDataSchema) + return err +} + +func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } - _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 15bf521f4..8a7d640e1 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -59,14 +59,15 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func createStateSnapshotTable(db *sql.DB) error { + _, err := db.Exec(stateSnapshotSchema) + return err +} + +func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, } - _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 351bbc316..c07ab507a 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -53,18 +53,22 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // which it will never obtain. db.SetMaxOpenConns(20) - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - ms := membershipStatements{} - if err := ms.execSchema(db); err != nil { + // Create the tables. + if err := d.create(db); err != nil { return nil, err } + + // Then execute the migrations. By this point the tables are created with the latest + // schemas. m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } + + // Then prepare the statements. Now that the migrations have run, any columns referred + // to in the database code should now exist. if err := d.prepare(db, cache); err != nil { return nil, err } @@ -72,62 +76,107 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) return &d, nil } -// nolint: gocyclo +func (d *Database) create(db *sql.DB) error { + if err := createEventStateKeysTable(db); err != nil { + return err + } + if err := createEventTypesTable(db); err != nil { + return err + } + if err := createEventJSONTable(db); err != nil { + return err + } + if err := createEventsTable(db); err != nil { + return err + } + if err := createRoomsTable(db); err != nil { + return err + } + if err := createTransactionsTable(db); err != nil { + return err + } + if err := createStateBlockTable(db); err != nil { + return err + } + if err := createStateSnapshotTable(db); err != nil { + return err + } + if err := createPrevEventsTable(db); err != nil { + return err + } + if err := createRoomAliasesTable(db); err != nil { + return err + } + if err := createInvitesTable(db); err != nil { + return err + } + if err := createMembershipTable(db); err != nil { + return err + } + if err := createPublishedTable(db); err != nil { + return err + } + if err := createRedactionsTable(db); err != nil { + return err + } + + return nil +} + func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { - var err error - eventStateKeys, err := NewSqliteEventStateKeysTable(db) + eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := NewSqliteEventTypesTable(db) + eventTypes, err := prepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := NewSqliteEventJSONTable(db) + eventJSON, err := prepareEventJSONTable(db) if err != nil { return err } - events, err := NewSqliteEventsTable(db) + events, err := prepareEventsTable(db) if err != nil { return err } - rooms, err := NewSqliteRoomsTable(db) + rooms, err := prepareRoomsTable(db) if err != nil { return err } - transactions, err := NewSqliteTransactionsTable(db) + transactions, err := prepareTransactionsTable(db) if err != nil { return err } - stateBlock, err := NewSqliteStateBlockTable(db) + stateBlock, err := prepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := NewSqliteStateSnapshotTable(db) + stateSnapshot, err := prepareStateSnapshotTable(db) if err != nil { return err } - prevEvents, err := NewSqlitePrevEventsTable(db) + prevEvents, err := preparePrevEventsTable(db) if err != nil { return err } - roomAliases, err := NewSqliteRoomAliasesTable(db) + roomAliases, err := prepareRoomAliasesTable(db) if err != nil { return err } - invites, err := NewSqliteInvitesTable(db) + invites, err := prepareInvitesTable(db) if err != nil { return err } - membership, err := NewSqliteMembershipTable(db) + membership, err := prepareMembershipTable(db) if err != nil { return err } - published, err := NewSqlitePublishedTable(db) + published, err := preparePublishedTable(db) if err != nil { return err } - redactions, err := NewSqliteRedactionsTable(db) + redactions, err := prepareRedactionsTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 029122c5e..e7471d7b0 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -49,14 +49,15 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { +func createTransactionsTable(db *sql.DB) error { + _, err := db.Exec(transactionsSchema) + return err +} + +func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{ db: db, } - _, err := db.Exec(transactionsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL},