Fix deleting devices by local part

This commit is contained in:
Neil Alexander 2020-09-04 14:53:27 +01:00
parent d600113296
commit 1c858b867b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
6 changed files with 23 additions and 25 deletions

View file

@ -135,11 +135,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
if d.ID != req.ExceptDeviceID { deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} }
} else { } else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
} }

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil