Only care about wildcard when targeted locally

This commit is contained in:
Neil Alexander 2020-05-29 17:37:13 +01:00
parent aba631d86c
commit 5cf900428c
10 changed files with 63 additions and 47 deletions

View file

@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs() rsAPI := base.CreateHTTPRoomserverAPIs()
fsAPI := base.CreateHTTPFederationSenderAPIs() fsAPI := base.CreateHTTPFederationSenderAPIs()
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
clientapi.SetupClientAPIComponent( clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, federation, keyRing, base, deviceDB, accountDB, federation, keyRing,

View file

@ -148,7 +148,7 @@ func main() {
&base.Base, keyRing, federation, &base.Base, keyRing, federation,
) )
eduInputAPI := eduserver.SetupEDUServerComponent( eduInputAPI := eduserver.SetupEDUServerComponent(
&base.Base, cache.New(), &base.Base, cache.New(), deviceDB,
) )
asAPI := appservice.SetupAppServiceAPIComponent( asAPI := appservice.SetupAppServiceAPIComponent(
&base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(), &base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(),

View file

@ -29,8 +29,9 @@ func main() {
logrus.WithError(err).Warn("BaseDendrite close failed") logrus.WithError(err).Warn("BaseDendrite close failed")
} }
}() }()
deviceDB := base.CreateDeviceDB()
eduserver.SetupEDUServerComponent(base, cache.New()) eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer))

View file

@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs() rsAPI := base.CreateHTTPRoomserverAPIs()
asAPI := base.CreateHTTPAppServiceAPIs() asAPI := base.CreateHTTPAppServiceAPIs()
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
eduProducer := producers.NewEDUServerProducer(eduInputAPI) eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent( federationapi.SetupFederationAPIComponent(

View file

@ -87,7 +87,7 @@ func main() {
} }
eduInputAPI := eduserver.SetupEDUServerComponent( eduInputAPI := eduserver.SetupEDUServerComponent(
base, cache.New(), base, cache.New(), deviceDB,
) )
if base.EnableHTTPAPIs { if base.EnableHTTPAPIs {
eduInputAPI = base.CreateHTTPEDUServerAPIs() eduInputAPI = base.CreateHTTPEDUServerAPIs()

View file

@ -13,6 +13,7 @@
package eduserver package eduserver
import ( import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/eduserver/input" "github.com/matrix-org/dendrite/eduserver/input"
@ -26,12 +27,15 @@ import (
func SetupEDUServerComponent( func SetupEDUServerComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
eduCache *cache.EDUCache, eduCache *cache.EDUCache,
deviceDB devices.Database,
) api.EDUServerInputAPI { ) api.EDUServerInputAPI {
inputAPI := &input.EDUServerInputAPI{ inputAPI := &input.EDUServerInputAPI{
Cache: eduCache, Cache: eduCache,
DeviceDB: deviceDB,
Producer: base.KafkaProducer, Producer: base.KafkaProducer,
OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent),
OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEventTopic), OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEventTopic),
ServerName: base.Cfg.Matrix.ServerName,
} }
inputAPI.SetupHTTP(base.InternalAPIMux) inputAPI.SetupHTTP(base.InternalAPIMux)

View file

@ -20,6 +20,7 @@ import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -38,6 +39,10 @@ type EDUServerInputAPI struct {
OutputSendToDeviceEventTopic string OutputSendToDeviceEventTopic string
// kafka producer // kafka producer
Producer sarama.SyncProducer Producer sarama.SyncProducer
// device database
DeviceDB devices.Database
// our server name
ServerName gomatrixserverlib.ServerName
} }
// InputTypingEvent implements api.EDUServerInputAPI // InputTypingEvent implements api.EDUServerInputAPI
@ -104,9 +109,28 @@ func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error {
} }
func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error { func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error {
devices := []string{}
localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID)
if err != nil {
return err
}
if domain == t.ServerName && ise.DeviceID == "*" {
devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart)
if err != nil {
return err
}
for _, dev := range devs {
devices = append(devices, dev.ID)
}
} else {
devices = append(devices, ise.DeviceID)
}
for _, device := range devices {
ote := &api.OutputSendToDeviceEvent{ ote := &api.OutputSendToDeviceEvent{
UserID: ise.UserID, UserID: ise.UserID,
DeviceID: ise.DeviceID, DeviceID: device,
SendToDeviceEvent: ise.SendToDeviceEvent, SendToDeviceEvent: ise.SendToDeviceEvent,
} }
@ -131,9 +155,12 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e
_, _, err = t.Producer.SendMessage(m) _, _, err = t.Producer.SendMessage(m)
if err != nil { if err != nil {
logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage") logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage")
}
return err return err
} }
}
return nil
}
// SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux. // SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux.
func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) { func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) {

View file

@ -19,7 +19,6 @@ import (
"encoding/json" "encoding/json"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
@ -35,7 +34,6 @@ import (
type OutputSendToDeviceEventConsumer struct { type OutputSendToDeviceEventConsumer struct {
sendToDeviceConsumer *internal.ContinualConsumer sendToDeviceConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
deviceDB devices.Database
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
notifier *sync.Notifier notifier *sync.Notifier
} }
@ -47,7 +45,6 @@ func NewOutputSendToDeviceEventConsumer(
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier, n *sync.Notifier,
store storage.Database, store storage.Database,
deviceDB devices.Database,
) *OutputSendToDeviceEventConsumer { ) *OutputSendToDeviceEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -59,7 +56,6 @@ func NewOutputSendToDeviceEventConsumer(
s := &OutputSendToDeviceEventConsumer{ s := &OutputSendToDeviceEventConsumer{
sendToDeviceConsumer: &consumer, sendToDeviceConsumer: &consumer,
db: store, db: store,
deviceDB: deviceDB,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
notifier: n, notifier: n,
} }
@ -82,7 +78,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
return err return err
} }
localpart, domain, err := gomatrixserverlib.SplitID('@', output.UserID) _, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil { if err != nil {
return err return err
} }
@ -107,22 +103,9 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
return err return err
} }
devices := []string{}
if output.DeviceID == "*" {
devs, err := s.deviceDB.GetDevicesByLocalpart(context.TODO(), localpart)
if err != nil {
return err
}
for _, dev := range devs {
devices = append(devices, dev.ID)
}
} else {
devices = append(devices, output.DeviceID)
}
s.notifier.OnNewSendToDevice( s.notifier.OnNewSendToDevice(
output.UserID, output.UserID,
devices, []string{output.DeviceID},
types.NewStreamToken(0, streamPos), types.NewStreamToken(0, streamPos),
) )

View file

@ -82,7 +82,7 @@ func SetupSyncAPIComponent(
} }
sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB, deviceDB, base.Cfg, base.KafkaConsumer, notifier, syncDB,
) )
if err = sendToDeviceConsumer.Start(); err != nil { if err = sendToDeviceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start send-to-device consumer") logrus.WithError(err).Panicf("failed to start send-to-device consumer")

View file

@ -300,3 +300,4 @@ Can send messages with a wildcard device id
Can send messages with a wildcard device id to two devices Can send messages with a wildcard device id to two devices
Wildcard device messages wake up /sync Wildcard device messages wake up /sync
Wildcard device messages over federation wake up /sync Wildcard device messages over federation wake up /sync
Can send a to-device message to two users which both receive it using /sync