diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index 814dd70a6..047e787c0 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -182,11 +182,9 @@ func (p *SyncAPIProducer) SendPresence( m := nats.NewMsg(p.TopicPresenceEvent) m.Header.Set(jetstream.UserID, userID) m.Header.Set("presence", presence) - nilMsg := statusMsg == nil - if !nilMsg { + if statusMsg != nil { m.Header.Set("status_msg", *statusMsg) } - m.Header.Set("status_msg_nil", strconv.FormatBool(nilMsg)) m.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 2c5cc5e9c..8439cc6e8 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -84,10 +84,8 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b } presence := msg.Header.Get("presence") - statusMsg := msg.Header.Get("status_msg") - nilStatusMsg, _ := strconv.ParseBool(msg.Header.Get("status_msg_nil")) - ts, err := strconv.Atoi(msg.Header.Get("last_active_ts")) + ts, err := strconv.Atoi(msg.Header.Get("last_active_ts")) if err != nil { return true } @@ -101,9 +99,10 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b return true } - newStatusMsg := &statusMsg - if nilStatusMsg { - newStatusMsg = nil + var statusMsg *string = nil + if data, ok := msg.Header["status_msg"]; ok && len(data) > 0 { + status := msg.Header.Get("status_msg") + statusMsg = &status } p := types.Presence{LastActiveTS: gomatrixserverlib.Timestamp(ts)} @@ -114,7 +113,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b CurrentlyActive: p.CurrentlyActive(), LastActiveAgo: p.LastActiveAgo(), Presence: presence, - StatusMsg: newStatusMsg, + StatusMsg: statusMsg, UserID: userID, }, }, diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 81c289f27..e80d50521 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -154,7 +154,6 @@ func (p *SyncAPIProducer) SendPresence( if statusMsg != nil { m.Header.Set("status_msg", *statusMsg) } - m.Header.Set("status_msg_nil", strconv.FormatBool(statusMsg == nil)) lastActiveTS := gomatrixserverlib.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) m.Header.Set("last_active_ts", strconv.Itoa(int(lastActiveTS))) diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index f768d72cd..e4be2477d 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -124,10 +124,8 @@ func (s *PresenceConsumer) Start() error { func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { userID := msg.Header.Get(jetstream.UserID) presence := msg.Header.Get("presence") - statusMsg := msg.Header.Get("status_msg") timestamp := msg.Header.Get("last_active_ts") fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync")) - nilStatusMsg, _ := strconv.ParseBool(msg.Header.Get("status_msg_nil")) logrus.Debugf("syncAPI received presence event: %+v", msg.Header) @@ -136,12 +134,13 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { return true } - newStatusMsg := &statusMsg - if nilStatusMsg { - newStatusMsg = nil + var statusMsg *string = nil + if data, ok := msg.Header["status_msg"]; ok && len(data) > 0 { + newMsg := msg.Header.Get("status_msg") + statusMsg = &newMsg } - pos, err := s.db.UpdatePresence(ctx, userID, presence, newStatusMsg, gomatrixserverlib.Timestamp(ts), fromSync) + pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync) if err != nil { return true } diff --git a/syncapi/producers/federationapi_presence.go b/syncapi/producers/federationapi_presence.go index 09b6a9271..980e793d6 100644 --- a/syncapi/producers/federationapi_presence.go +++ b/syncapi/producers/federationapi_presence.go @@ -31,7 +31,7 @@ type FederationAPIPresenceProducer struct { } func (f *FederationAPIPresenceProducer) SendPresence( - userID, presence string, + userID, presence string, statusMsg *string, ) error { msg := nats.NewMsg(f.Topic) msg.Header.Set(jetstream.UserID, userID) @@ -39,6 +39,10 @@ func (f *FederationAPIPresenceProducer) SendPresence( msg.Header.Set("from_sync", "true") // only update last_active_ts and presence msg.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) + if statusMsg != nil { + msg.Header.Set("status_msg", *statusMsg) + } + _, err := f.JetStream.PublishMsg(msg) return err } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 7fa943b3d..e1004dd7e 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -17,6 +17,8 @@ package sync import ( + "context" + "database/sql" "net" "net/http" "strings" @@ -54,7 +56,7 @@ type RequestPool struct { } type PresencePublisher interface { - SendPresence(userID, presence string) error + SendPresence(userID, presence string, statusMsg *string) error } // NewRequestPool makes a new RequestPool @@ -78,7 +80,7 @@ func NewRequestPool( producer: producer, } go rp.cleanLastSeen() - go rp.cleanPresence(time.Minute * 5) + go rp.cleanPresence(db, time.Minute*5) return rp } @@ -92,12 +94,12 @@ func (rp *RequestPool) cleanLastSeen() { } } -func (rp *RequestPool) cleanPresence(cleanupTime time.Duration) { +func (rp *RequestPool) cleanPresence(db storage.Presence, cleanupTime time.Duration) { for { rp.presence.Range(func(key interface{}, v interface{}) bool { p := v.(types.Presence) if time.Since(p.LastActiveTS.Time()) > cleanupTime { - rp.updatePresence("unavailable", p.UserID) + rp.updatePresence(db, "unavailable", p.UserID) rp.presence.Delete(key) } return true @@ -107,7 +109,7 @@ func (rp *RequestPool) cleanPresence(cleanupTime time.Duration) { } // updatePresence sends presence updates to the SyncAPI and FederationAPI -func (rp *RequestPool) updatePresence(presence string, userID string) { +func (rp *RequestPool) updatePresence(db storage.Presence, presence string, userID string) { if rp.cfg.Matrix.DisablePresence { return } @@ -132,7 +134,13 @@ func (rp *RequestPool) updatePresence(presence string, userID string) { } } - if err := rp.producer.SendPresence(userID, strings.ToLower(presence)); err != nil { + // ensure we also send the current status_msg to federated servers and not nil + dbPresence, err := db.GetPresence(context.Background(), userID) + if err != nil && err != sql.ErrNoRows { + return + } + + if err := rp.producer.SendPresence(userID, strings.ToLower(presence), dbPresence.ClientFields.StatusMsg); err != nil { logrus.WithError(err).Error("Unable to publish presence message from sync") return } @@ -214,7 +222,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. defer activeSyncRequests.Dec() rp.updateLastSeen(req, device) - rp.updatePresence(req.FormValue("set_presence"), device.UserID) + rp.updatePresence(rp.db, req.FormValue("set_presence"), device.UserID) waitingSyncRequests.Inc() defer waitingSyncRequests.Dec() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 256e31d4e..f2e89bd99 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -1,22 +1,43 @@ package sync import ( + "context" "sync" "testing" "time" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" ) type dummyPublisher struct { count int } -func (d *dummyPublisher) SendPresence(userID, presence string) error { +func (d *dummyPublisher) SendPresence(userID, presence string, statusMsg *string) error { d.count++ return nil } +type dummyDB struct{} + +func (d dummyDB) UpdatePresence(ctx context.Context, userID, presence string, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + return 0, nil +} + +func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presence, error) { + return &types.Presence{}, nil +} + +func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.Presence, error) { + return map[string]*types.Presence{}, nil +} + +func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return 0, nil +} + func TestRequestPool_updatePresence(t *testing.T) { type args struct { presence string @@ -89,11 +110,12 @@ func TestRequestPool_updatePresence(t *testing.T) { }, }, } - go rp.cleanPresence(time.Millisecond * 50) + db := dummyDB{} + go rp.cleanPresence(db, time.Millisecond*50) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { beforeCount := publisher.count - rp.updatePresence(tt.args.presence, tt.args.userID) + rp.updatePresence(db, tt.args.presence, tt.args.userID) if tt.wantIncrease && publisher.count <= beforeCount { t.Fatalf("expected count to increase: %d <= %d", publisher.count, beforeCount) }