diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index f47a64c80..413163704 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -16,10 +16,13 @@ package postgres import ( + "context" "database/sql" "fmt" "github.com/lib/pq" + "github.com/sirupsen/logrus" + // Import the postgres database driver. _ "github.com/lib/pq" @@ -52,30 +55,8 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column event_nid was removed from the table - var eventNID int - err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID) - if err == nil { - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "roomserver: state blocks refactor", - Up: deltas.UpStateBlocksRefactor, - }) - if err = m.Up(base.Context()); err != nil { - return nil, err - } - } else { - switch e := err.(type) { - case *pq.Error: - // ignore undefined_column (42703) errors, as this is expected at this point - if e.Code != "42703" { - return nil, err - } - default: - return nil, err - } + if err = executeMigration(base.Context(), db); err != nil { + return nil, err } // Then prepare the statements. Now that the migrations have run, any columns referred @@ -87,6 +68,50 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c return &d, nil } +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column event_nid was removed from the table + migrationName := "roomserver: state blocks refactor" + var migrationCount int + + err := db.QueryRowContext(ctx, "SELECT count(*) FROM db_migrations WHERE version = $1", migrationName).Scan(&migrationCount) + if err != nil { + return err + } + if migrationCount > 0 { + return nil + } + + var eventNID int + err = db.QueryRowContext(ctx, "SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID) + if err == nil { + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpStateBlocksRefactor, + }) + if err = m.Up(ctx); err != nil { + return err + } + } else { + switch e := err.(type) { + case *pq.Error: + // ignore undefined_column (42703) errors, as this is expected at this point + if e.Code != "42703" { + return err + } + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + // not a fatal error, log and continue + logrus.WithError(err).Warnf("unable to manually insert migration '%s'", migrationName) + } + default: + return err + } + } + return nil +} + func (d *Database) create(db *sql.DB) error { if err := CreateEventStateKeysTable(db); err != nil { return err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 9f8a1b118..3c7996d85 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -19,8 +19,11 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/mattn/go-sqlite3" + "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -61,20 +64,8 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column event_nid was removed from the table - var eventNID int - err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID) - if err == nil { - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "roomserver: state blocks refactor", - Up: deltas.UpStateBlocksRefactor, - }) - if err = m.Up(base.Context()); err != nil { - return nil, err - } + if err = executeMigration(base.Context(), db); err != nil { + return nil, err } // Then prepare the statements. Now that the migrations have run, any columns referred @@ -86,6 +77,52 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c return &d, nil } +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column event_nid was removed from the table + migrationName := "roomserver: state blocks refactor" + var migrationCount int + + err := db.QueryRowContext(ctx, "SELECT count(*) FROM db_migrations WHERE version = $1", migrationName).Scan(&migrationCount) + if err != nil { + return err + } + if migrationCount > 0 { + return nil + } + + var eventNID int + err = db.QueryRowContext(ctx, "SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID) + if err == nil { + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpStateBlocksRefactor, + }) + if err = m.Up(ctx); err != nil { + return err + } + } else { + switch e := err.(type) { + case *sqlite3.Error: + // ignore "no such column" errors, as this is expected at this point + if !strings.Contains(e.Error(), "no such column") { + return err + } + + // reset the error to nil, so the deferred function will insert the migration + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + // not a fatal error, log and continue + logrus.WithError(err).Warnf("unable to manually insert migration '%s'", migrationName) + } + default: + return err + } + } + return nil +} + func (d *Database) create(db *sql.DB) error { if err := CreateEventStateKeysTable(db); err != nil { return err