From 8f54d33f1d888cb5c0b7810cbe4f683db03d2723 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 09:57:40 +0000 Subject: [PATCH] Fix keyserver consumer maybe --- keyserver/consumers/devicelistupdate.go | 26 +++++++++++------------ keyserver/consumers/signingkeyupdate.go | 28 +++++++++++++------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index 575e41281..cd911f8c6 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -30,12 +30,12 @@ import ( // DeviceListUpdateConsumer consumes device list updates that came in over federation. type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer( updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true - } else if serverName == t.serverName { + } else if t.isLocalServerName(serverName) { return true } else if serverName != origin { return true diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go index 366e259b4..bcceaad15 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/keyserver/consumers/signingkeyupdate.go @@ -31,12 +31,13 @@ import ( // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + keyAPI *internal.KeyInternalAPI + cfg *config.KeyServer + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer( keyAPI *internal.KeyInternalAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + keyAPI: keyAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true - } else if serverName == t.cfg.Matrix.ServerName { + } else if t.isLocalServerName(serverName) { logrus.Warn("dropping device key update from ourself") return true } else if serverName != origin {