diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 92ee80d81..37c55c8f8 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -33,16 +33,17 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI struct { - DB storage.Database - ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater + DB storage.Database + Cfg *config.KeyServer + FedClient fedsenderapi.KeyserverFederationAPI + UserAPI userapi.KeyserverUserAPI + Producer *producers.KeyChange + Updater *DeviceListUpdater } func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { @@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC nested[userID] = val domainToDeviceKeys[string(serverName)] = nested } - // claim local keys - if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + for domain, local := range domainToDeviceKeys { + if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys keys, err := a.DB.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ @@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON } } - delete(domainToDeviceKeys, string(a.ThisServer)) + delete(domainToDeviceKeys, domain) } if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) @@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if serverName == a.ThisServer { + if a.Cfg.Matrix.IsLocalServerName(serverName) { deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ @@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if serverName != a.ThisServer { + if !a.Cfg.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 9ae4f9ca3..0a4b8fde6 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -53,10 +53,10 @@ func NewInternalAPI( DB: db, } ap := &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, + DB: db, + Cfg: cfg, + FedClient: fedClient, + Producer: keyChangeProducer, } updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable ap.Updater = updater