From 8c57745c8207fd1850231ffbf3d255219125f365 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Aug 2020 17:55:56 +0100 Subject: [PATCH] Un-deadlock device database --- .../storage/devices/sqlite3/devices_table.go | 62 +++++++------------ userapi/storage/devices/sqlite3/storage.go | 8 +-- 2 files changed, 28 insertions(+), 42 deletions(-) diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index c902adc76..ddb2dbe45 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice( ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { - return err - } - sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { - return err - } - return nil - }) - if err != nil { + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } return &api.Device{ @@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err } func (s *devicesStatements) deleteDevices( @@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices( if err != nil { return err } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v - } - _, err = stmt.ExecContext(ctx, params...) - return err - }) + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err } func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err } func (s *devicesStatements) selectDeviceByToken( diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 99d4b1fb9..e46ad981d 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -140,7 +140,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -152,7 +152,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -167,7 +167,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -181,7 +181,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err }