diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index 004f15d82..5ad5a9749 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -63,30 +64,54 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { return s, err } + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + return s, nil +} + +func executeMigration(ctx context.Context, db *sql.DB) error { // TODO: Remove when we are sure we are not having goose artefacts in the db // This forces an error, which indicates the migration is already applied, since the // column partition was removed from the table - var count int - err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) + migrationName := "keyserver: refactor key changes" + var migrationCount int + + err := db.QueryRowContext(ctx, "SELECT count(*) FROM db_migrations WHERE version = $1", migrationName).Scan(&migrationCount) + if err != nil { + return err + } + if migrationCount > 0 { + return nil + } + + var partition int + err = db.QueryRowContext(ctx, "SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&partition) if err == nil { m := sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: refactor key changes", + Version: migrationName, Up: deltas.UpRefactorKeyChanges, }) - return s, m.Up(context.Background()) + if err = m.Up(ctx); err != nil { + return err + } } else { switch e := err.(type) { case *pq.Error: // ignore undefined_column (42703) errors, as this is expected at this point if e.Code != "42703" { - return nil, err + return err + } + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + // not a fatal error, log and continue + logrus.WithError(err).Warnf("unable to manually insert migration '%s'", migrationName) } default: - return nil, err + return err } } - return s, nil + return nil } func (s *keyChangesStatements) Prepare() (err error) { diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 217fa7a5d..f3d743cac 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -17,6 +17,10 @@ package sqlite3 import ( "context" "database/sql" + "strings" + + "github.com/mattn/go-sqlite3" + "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -58,23 +62,60 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { if err != nil { return s, err } - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column partition was removed from the table - var count int - err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count) - if err == nil { - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: refactor key changes", - Up: deltas.UpRefactorKeyChanges, - }) - return s, m.Up(context.Background()) + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err } return s, nil } +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + var migrationCount int + + err := db.QueryRowContext(ctx, "SELECT count(*) FROM db_migrations WHERE version = $1", migrationName).Scan(&migrationCount) + if err != nil { + return err + } + if migrationCount > 0 { + return nil + } + + var partition int + err = db.QueryRowContext(ctx, "SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&partition) + if err == nil { + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + if err = m.Up(ctx); err != nil { + return err + } + } else { + switch e := err.(type) { + case *sqlite3.Error: + // ignore "no such column" errors, as this is expected at this point + if !strings.Contains(e.Error(), "no such column") { + return err + } + + // reset the error to nil, so the deferred function will insert the migration + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + // not a fatal error, log and continue + logrus.WithError(err).Warnf("unable to manually insert migration '%s'", migrationName) + } + default: + return err + } + } + return nil +} + func (s *keyChangesStatements) Prepare() (err error) { if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { return err