From 688cd00b9c6bdbdc23df99794a46b5c3b42cb855 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 18 Aug 2021 11:28:35 +0100 Subject: [PATCH] Move loop to within database transaction --- keyserver/internal/internal.go | 8 +++----- keyserver/storage/interface.go | 2 +- keyserver/storage/shared/storage.go | 14 ++++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index bdb5afc41..a546e94b5 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -183,11 +183,9 @@ func (a *KeyInternalAPI) claimRemoteKeys( } func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) { - for _, keyID := range req.KeyIDs { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, keyID); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("Failed to delete device keys: %s", err), - } + if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to delete device keys: %s", err), } } } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 52a24791c..99842bc58 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -60,7 +60,7 @@ type Database interface { // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceID gomatrixserverlib.KeyID) error + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 664a613e9..5f73935fb 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -160,13 +160,15 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID gomatrixserverlib.KeyID) error { +func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } } return nil })