diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 254da84c3..4eaa5b581 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -75,11 +75,12 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } +func (s *accountsStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(accountsSchema) + return err +} + func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(accountsSchema) - if err != nil { - return - } if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 768bd208b..40c4b8ff5 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -57,6 +57,18 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver db: db, writer: sqlutil.NewDummyWriter(), } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = d.accounts.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadIsActive(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { return nil, err } @@ -73,10 +85,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - - return d, m.RunDeltas(db, dbProperties) + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index d0ea8a8bc..50f07237e 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -74,13 +74,13 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } +func (s *accountsStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(accountsSchema) + return err +} + func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - - _, err = db.Exec(accountsSchema) - if err != nil { - return - } if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index fe0d2eacc..0be7bcbe7 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -61,6 +61,18 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver db: db, writer: sqlutil.NewExclusiveWriter(), } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + if err = d.accounts.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadIsActive(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + partitions := sqlutil.PartitionOffsetStatements{} if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err @@ -77,10 +89,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } - m := sqlutil.NewMigrations() - deltas.LoadIsActive(m) - return d, m.RunDeltas(db, dbProperties) + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 2a4d337c7..379fed794 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -111,11 +111,12 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } +func (s *devicesStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(devicesSchema) + return err +} + func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(devicesSchema) - if err != nil { - return - } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index d1fabcf92..602051dcc 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -43,13 +43,21 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } d := devicesStatements{} + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + d.execSchema(db) + m := sqlutil.NewMigrations() + deltas.LoadLastSeenTSIP(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = d.prepare(db, serverName); err != nil { return nil, err } - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - return &Database{db, d}, m.RunDeltas(db, dbProperties) + return &Database{db, d}, nil } // GetDeviceByAccessToken returns the device matching the given access token. diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 6b0de10ee..26c03222a 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -98,13 +98,14 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } +func (s *devicesStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(devicesSchema) + return err +} + func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db s.writer = writer - _, err = db.Exec(devicesSchema) - if err != nil { - return - } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 370d0d61a..8c02372db 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -47,12 +47,19 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } writer := sqlutil.NewExclusiveWriter() d := devicesStatements{} + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + d.execSchema(db) + m := sqlutil.NewMigrations() + deltas.LoadLastSeenTSIP(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - return &Database{db, writer, d}, m.RunDeltas(db, dbProperties) + return &Database{db, writer, d}, nil } // GetDeviceByAccessToken returns the device matching the given access token.