diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 11bca6474..c9ec59a75 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -340,5 +340,6 @@ type QuerySignaturesResponse struct { type PerformMarkAsStaleRequest struct { UserID string + Domain gomatrixserverlib.ServerName DeviceID string } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 569589516..2d85734ba 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -233,11 +233,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap return err } if len(knownDevices) == 0 { - _, remoteServer, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - return a.Updater.ManualUpdate(ctx, remoteServer, req.UserID) + return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) } return nil } diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index aaf1879c5..1735db917 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -109,14 +109,14 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] }) logger.Debugf("sync API received send-to-device event from the clientapi/federationsender") - // Check we actually got the requesting device in our store + // Check we actually got the requesting device in our store, if we receive a room key request if output.Type == "m.room_key_request" { requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str - _, domain, _ := gomatrixserverlib.SplitID('@', output.Sender) - if requestingDeviceID != "" && domain != s.serverName { + _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) + if requestingDeviceID != "" && senderDomain != s.serverName { // Mark the requesting device as stale, if we don't know about it. if err := s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ - UserID: output.Sender, DeviceID: requestingDeviceID, + UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, }, &struct{}{}); err != nil { logger.WithError(err).Errorf("failed to mark as stale if needed") return false