mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Remove AddMigration
Use WithTransaction Add Dendrite version to table
This commit is contained in:
parent
673e0f601b
commit
ffe666fef9
|
|
@ -85,7 +85,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(d.db)
|
m := sqlutil.NewMigrator(d.db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "drop federationsender_rooms",
|
Version: "drop federationsender_rooms",
|
||||||
Up: deltas.UpRemoveRoomsTable,
|
Up: deltas.UpRemoveRoomsTable,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(d.db)
|
m := sqlutil.NewMigrator(d.db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "drop federationsender_rooms",
|
Version: "drop federationsender_rooms",
|
||||||
Up: deltas.UpRemoveRoomsTable,
|
Up: deltas.UpRemoveRoomsTable,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,24 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const createDBMigrationsSQL = "" +
|
||||||
|
"CREATE TABLE IF NOT EXISTS db_migrations (" +
|
||||||
|
" version TEXT PRIMARY KEY," +
|
||||||
|
" time TEXT," +
|
||||||
|
" dendrite_version TEXT" +
|
||||||
|
");"
|
||||||
|
|
||||||
|
const insertVersionSQL = "" +
|
||||||
|
"INSERT INTO db_migrations (version, time, dendrite_version)" +
|
||||||
|
" VALUES ($1, $2, $3) " +
|
||||||
|
" ON CONFLICT(version) DO UPDATE SET dendrite_version = $4, time = $5"
|
||||||
|
|
||||||
|
const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
|
||||||
|
|
||||||
// Migration defines a migration to be run.
|
// Migration defines a migration to be run.
|
||||||
type Migration struct {
|
type Migration struct {
|
||||||
// Version is a simple name description/name of this migration
|
// Version is a simple name description/name of this migration
|
||||||
|
|
@ -50,73 +66,68 @@ func NewMigrator(db *sql.DB) *Migrator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMigration adds new migrations to the list.
|
// AddMigrations adds new migrations to the list.
|
||||||
// De-duplicates migrations by their version
|
// De-duplicates migrations by their version
|
||||||
func (m *Migrator) AddMigration(migration Migration) {
|
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
if !m.knownMigrations[migration.Version] {
|
|
||||||
m.migrations = append(m.migrations, migration)
|
|
||||||
m.knownMigrations[migration.Version] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddMigrations is a convenience method to add migrations
|
|
||||||
func (m *Migrator) AddMigrations(migrations ...Migration) {
|
|
||||||
for _, mig := range migrations {
|
for _, mig := range migrations {
|
||||||
m.AddMigration(mig)
|
if !m.knownMigrations[mig.Version] {
|
||||||
|
m.migrations = append(m.migrations, mig)
|
||||||
|
m.knownMigrations[mig.Version] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Up executes all migrations
|
// Up executes all migrations
|
||||||
func (m *Migrator) Up(ctx context.Context) error {
|
func (m *Migrator) Up(ctx context.Context) error {
|
||||||
var err error
|
var (
|
||||||
|
err error
|
||||||
|
dendriteVersion = internal.VersionString()
|
||||||
|
)
|
||||||
// ensure there is a table for known migrations
|
// ensure there is a table for known migrations
|
||||||
executedMigrations, err := m.ExecutedMigrations(ctx)
|
executedMigrations, err := m.ExecutedMigrations(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create/get migrations: %w", err)
|
return fmt.Errorf("unable to create/get migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
txn, err := m.db.BeginTx(ctx, nil)
|
return WithTransaction(m.db, func(txn *sql.Tx) error {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to begin transaction: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
_ = txn.Rollback()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := range m.migrations {
|
for i := range m.migrations {
|
||||||
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
migration := m.migrations[i]
|
migration := m.migrations[i]
|
||||||
if !executedMigrations[migration.Version] {
|
if !executedMigrations[migration.Version] {
|
||||||
err = migration.Up(ctx, txn)
|
err = migration.Up(ctx, txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
|
||||||
}
|
}
|
||||||
_, err = txn.ExecContext(ctx, "INSERT INTO db_migrations (version, time) VALUES ($1, $2)", migration.Version, time.Now().UTC().Format(time.RFC3339))
|
_, err = txn.ExecContext(ctx, insertVersionSQL,
|
||||||
|
migration.Version,
|
||||||
|
now,
|
||||||
|
dendriteVersion,
|
||||||
|
dendriteVersion,
|
||||||
|
now,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
return fmt.Errorf("unable to insert executed migrations: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = txn.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("unable to commit transaction: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecutedMigrations returns a map with already executed migrations
|
// ExecutedMigrations returns a map with already executed migrations
|
||||||
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
|
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]bool, error) {
|
||||||
result := make(map[string]bool)
|
result := make(map[string]bool)
|
||||||
_, err := m.db.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS db_migrations ( version TEXT, time TEXT );")
|
_, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
|
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
|
||||||
}
|
}
|
||||||
rows, err := m.db.QueryContext(ctx, "SELECT version FROM db_migrations")
|
rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
|
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
var version string
|
var version string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&version); err != nil {
|
if err := rows.Scan(&version); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
|
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "refactor key changes",
|
Version: "refactor key changes",
|
||||||
Up: deltas.UpRefactorKeyChanges,
|
Up: deltas.UpRefactorKeyChanges,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
|
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "refactor key changes",
|
Version: "refactor key changes",
|
||||||
Up: deltas.UpRefactorKeyChanges,
|
Up: deltas.UpRefactorKeyChanges,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -167,7 +167,7 @@ func createMembershipTable(db *sql.DB) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "add forgotten column",
|
Version: "add forgotten column",
|
||||||
Up: deltas.UpAddForgottenColumn,
|
Up: deltas.UpAddForgottenColumn,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -57,12 +57,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
|
||||||
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
|
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigrations([]sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
{
|
|
||||||
Version: "state blocks refactor",
|
Version: "state blocks refactor",
|
||||||
Up: deltas.UpStateBlocksRefactor,
|
Up: deltas.UpStateBlocksRefactor,
|
||||||
},
|
})
|
||||||
}...)
|
|
||||||
if err := m.Up(context.Background()); err != nil {
|
if err := m.Up(context.Background()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -65,12 +65,10 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
|
||||||
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
|
err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&count)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigrations([]sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
{
|
|
||||||
Version: "state blocks refactor",
|
Version: "state blocks refactor",
|
||||||
Up: deltas.UpStateBlocksRefactor,
|
Up: deltas.UpStateBlocksRefactor,
|
||||||
},
|
})
|
||||||
}...)
|
|
||||||
if err := m.Up(context.Background()); err != nil {
|
if err := m.Up(context.Background()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "fix sequences",
|
Version: "fix sequences",
|
||||||
Up: deltas.UpFixSequences,
|
Up: deltas.UpFixSequences,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "drop sent_by_token",
|
Version: "drop sent_by_token",
|
||||||
Up: deltas.UpRemoveSendToDeviceSentColumn,
|
Up: deltas.UpRemoveSendToDeviceSentColumn,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Re
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "fix sequences",
|
Version: "fix sequences",
|
||||||
Up: deltas.UpFixSequences,
|
Up: deltas.UpFixSequences,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "drop sent_by_token",
|
Version: "drop sent_by_token",
|
||||||
Up: deltas.UpRemoveSendToDeviceSentColumn,
|
Up: deltas.UpRemoveSendToDeviceSentColumn,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigration(sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
Version: "add last_seen_ts",
|
Version: "add last_seen_ts",
|
||||||
Up: deltas.UpLastSeenTSIP,
|
Up: deltas.UpLastSeenTSIP,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -109,12 +109,10 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrator(db)
|
m := sqlutil.NewMigrator(db)
|
||||||
m.AddMigrations([]sqlutil.Migration{
|
m.AddMigrations(sqlutil.Migration{
|
||||||
{
|
|
||||||
Version: "add last_seen_ts",
|
Version: "add last_seen_ts",
|
||||||
Up: deltas.UpLastSeenTSIP,
|
Up: deltas.UpLastSeenTSIP,
|
||||||
},
|
})
|
||||||
}...)
|
|
||||||
err = m.Up(context.Background())
|
err = m.Up(context.Background())
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertDeviceStmt, insertDeviceSQL},
|
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue