Move loop to within database transaction

This commit is contained in:
Neil Alexander 2021-08-18 11:28:35 +01:00
parent 8ef8e073fd
commit 688cd00b9c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 12 additions and 12 deletions

View file

@ -183,11 +183,9 @@ func (a *KeyInternalAPI) claimRemoteKeys(
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
for _, keyID := range req.KeyIDs {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, keyID); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
}
}

View file

@ -60,7 +60,7 @@ type Database interface {
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
DeleteDeviceKeys(ctx context.Context, userID string, deviceID gomatrixserverlib.KeyID) error
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.

View file

@ -160,13 +160,15 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID gomatrixserverlib.KeyID) error {
func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
for _, deviceID := range deviceIDs {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
}
}
return nil
})