mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 14:53:10 -06:00
Fix deleting devices by local part
This commit is contained in:
parent
d600113296
commit
1c858b867b
|
|
@ -135,12 +135,10 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
deletedDeviceIDs := req.DeviceIDs
|
deletedDeviceIDs := req.DeviceIDs
|
||||||
if len(req.DeviceIDs) == 0 {
|
if len(req.DeviceIDs) == 0 {
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
|
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
if d.ID != req.ExceptDeviceID {
|
|
||||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
|
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,5 +36,5 @@ type Database interface {
|
||||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
|
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||||
|
|
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
|
||||||
// deleteDevicesByLocalpart removes all devices for the
|
// deleteDevicesByLocalpart removes all devices for the
|
||||||
// given user localpart.
|
// given user localpart.
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart)
|
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
|
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
|
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
|
||||||
// database matching the given user ID localpart.
|
// database matching the given user ID localpart.
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
) (devices []api.Device, err error) {
|
) (devices []api.Device, err error) {
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
|
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
|
|
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart)
|
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
|
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
|
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
|
||||||
// database matching the given user ID localpart.
|
// database matching the given user ID localpart.
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
) (devices []api.Device, err error) {
|
) (devices []api.Device, err error) {
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
|
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue