diff --git a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go index 5f3c813ff..6173e2077 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go @@ -16,8 +16,6 @@ package consumers import ( "encoding/json" - "fmt" - "strings" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" @@ -33,6 +31,7 @@ type OutputRoomEvent struct { roomServerConsumer *common.ContinualConsumer db *accounts.Database query api.RoomserverQueryAPI + serverName string } // NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers. @@ -52,6 +51,7 @@ func NewOutputRoomEvent(cfg *config.Dendrite, store *accounts.Database) (*Output roomServerConsumer: &consumer, db: store, query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil), + serverName: string(cfg.Matrix.ServerName), } consumer.ProcessMessage = s.onMessage @@ -68,18 +68,21 @@ func (s *OutputRoomEvent) Start() error { // sync stream position may race and be incorrectly calculated. func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { // Parse out the event JSON - var output api.OutputRoomEvent + var output api.OutputEvent if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Errorf("roomserver output log: message parse failure") return nil } - ev, err := gomatrixserverlib.NewEventFromTrustedJSON(output.Event, false) - if err != nil { - log.WithError(err).Errorf("roomserver output log: event parse failure") + if output.Type != api.OutputTypeNewRoomEvent { + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) return nil } + + ev := output.NewRoomEvent.Event log.WithFields(log.Fields{ "event_id": ev.EventID(), "room_id": ev.RoomID(), @@ -87,12 +90,22 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { }).Info("received event from roomserver") if ev.Type() == "m.room.member" && ev.StateKey() != nil { - localpart := getLocalPart(*ev.StateKey()) + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // we only want state events from local users + if string(serverName) != s.serverName { + return nil + } + roomID := ev.RoomID() membership, err := ev.Membership() if err != nil { return err } + switch membership { case "join": if err := s.db.SaveMembership(localpart, roomID); err != nil { @@ -108,14 +121,3 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { return nil } - -func getLocalPart(userID string) string { - if !strings.HasPrefix(userID, "@") { - panic(fmt.Errorf("Invalid user ID")) - } - - // Get the part before ":" - username := strings.Split(userID, ":")[0] - // Return the part after the "@" - return strings.Split(username, "@")[1] -}