diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 8c7b610ed..461c548cc 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -135,11 +135,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { - if d.ID != req.ExceptDeviceID { - deletedDeviceIDs = append(deletedDeviceIDs, d.ID) - } + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9b4261c9d..168c84c5c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -36,5 +36,5 @@ type Database interface { RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 282466f8d..c06af7549 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -79,7 +79,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" @@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 04dae9864..c5bd5b6cf 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -175,14 +175,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ecf43524a..c75e19825 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -68,7 +68,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" @@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices( } func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f775fb664..7c6645dd6 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -179,14 +179,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil