Remove status_msg_nil

This commit is contained in:
Till Faelligen 2022-04-04 08:49:35 +02:00
parent 16b8502bdb
commit 552fce0fc9
7 changed files with 57 additions and 28 deletions

View file

@ -182,11 +182,9 @@ func (p *SyncAPIProducer) SendPresence(
m := nats.NewMsg(p.TopicPresenceEvent) m := nats.NewMsg(p.TopicPresenceEvent)
m.Header.Set(jetstream.UserID, userID) m.Header.Set(jetstream.UserID, userID)
m.Header.Set("presence", presence) m.Header.Set("presence", presence)
nilMsg := statusMsg == nil if statusMsg != nil {
if !nilMsg {
m.Header.Set("status_msg", *statusMsg) m.Header.Set("status_msg", *statusMsg)
} }
m.Header.Set("status_msg_nil", strconv.FormatBool(nilMsg))
m.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) m.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now()))))

View file

@ -84,10 +84,8 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b
} }
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
statusMsg := msg.Header.Get("status_msg")
nilStatusMsg, _ := strconv.ParseBool(msg.Header.Get("status_msg_nil"))
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
if err != nil { if err != nil {
return true return true
} }
@ -101,9 +99,10 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b
return true return true
} }
newStatusMsg := &statusMsg var statusMsg *string = nil
if nilStatusMsg { if data, ok := msg.Header["status_msg"]; ok && len(data) > 0 {
newStatusMsg = nil status := msg.Header.Get("status_msg")
statusMsg = &status
} }
p := types.Presence{LastActiveTS: gomatrixserverlib.Timestamp(ts)} p := types.Presence{LastActiveTS: gomatrixserverlib.Timestamp(ts)}
@ -114,7 +113,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) b
CurrentlyActive: p.CurrentlyActive(), CurrentlyActive: p.CurrentlyActive(),
LastActiveAgo: p.LastActiveAgo(), LastActiveAgo: p.LastActiveAgo(),
Presence: presence, Presence: presence,
StatusMsg: newStatusMsg, StatusMsg: statusMsg,
UserID: userID, UserID: userID,
}, },
}, },

View file

@ -154,7 +154,6 @@ func (p *SyncAPIProducer) SendPresence(
if statusMsg != nil { if statusMsg != nil {
m.Header.Set("status_msg", *statusMsg) m.Header.Set("status_msg", *statusMsg)
} }
m.Header.Set("status_msg_nil", strconv.FormatBool(statusMsg == nil))
lastActiveTS := gomatrixserverlib.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) lastActiveTS := gomatrixserverlib.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond)))
m.Header.Set("last_active_ts", strconv.Itoa(int(lastActiveTS))) m.Header.Set("last_active_ts", strconv.Itoa(int(lastActiveTS)))

View file

@ -124,10 +124,8 @@ func (s *PresenceConsumer) Start() error {
func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
statusMsg := msg.Header.Get("status_msg")
timestamp := msg.Header.Get("last_active_ts") timestamp := msg.Header.Get("last_active_ts")
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync")) fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
nilStatusMsg, _ := strconv.ParseBool(msg.Header.Get("status_msg_nil"))
logrus.Debugf("syncAPI received presence event: %+v", msg.Header) logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
@ -136,12 +134,13 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
return true return true
} }
newStatusMsg := &statusMsg var statusMsg *string = nil
if nilStatusMsg { if data, ok := msg.Header["status_msg"]; ok && len(data) > 0 {
newStatusMsg = nil newMsg := msg.Header.Get("status_msg")
statusMsg = &newMsg
} }
pos, err := s.db.UpdatePresence(ctx, userID, presence, newStatusMsg, gomatrixserverlib.Timestamp(ts), fromSync) pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
if err != nil { if err != nil {
return true return true
} }

View file

@ -31,7 +31,7 @@ type FederationAPIPresenceProducer struct {
} }
func (f *FederationAPIPresenceProducer) SendPresence( func (f *FederationAPIPresenceProducer) SendPresence(
userID, presence string, userID, presence string, statusMsg *string,
) error { ) error {
msg := nats.NewMsg(f.Topic) msg := nats.NewMsg(f.Topic)
msg.Header.Set(jetstream.UserID, userID) msg.Header.Set(jetstream.UserID, userID)
@ -39,6 +39,10 @@ func (f *FederationAPIPresenceProducer) SendPresence(
msg.Header.Set("from_sync", "true") // only update last_active_ts and presence msg.Header.Set("from_sync", "true") // only update last_active_ts and presence
msg.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) msg.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now()))))
if statusMsg != nil {
msg.Header.Set("status_msg", *statusMsg)
}
_, err := f.JetStream.PublishMsg(msg) _, err := f.JetStream.PublishMsg(msg)
return err return err
} }

View file

@ -17,6 +17,8 @@
package sync package sync
import ( import (
"context"
"database/sql"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -54,7 +56,7 @@ type RequestPool struct {
} }
type PresencePublisher interface { type PresencePublisher interface {
SendPresence(userID, presence string) error SendPresence(userID, presence string, statusMsg *string) error
} }
// NewRequestPool makes a new RequestPool // NewRequestPool makes a new RequestPool
@ -78,7 +80,7 @@ func NewRequestPool(
producer: producer, producer: producer,
} }
go rp.cleanLastSeen() go rp.cleanLastSeen()
go rp.cleanPresence(time.Minute * 5) go rp.cleanPresence(db, time.Minute*5)
return rp return rp
} }
@ -92,12 +94,12 @@ func (rp *RequestPool) cleanLastSeen() {
} }
} }
func (rp *RequestPool) cleanPresence(cleanupTime time.Duration) { func (rp *RequestPool) cleanPresence(db storage.Presence, cleanupTime time.Duration) {
for { for {
rp.presence.Range(func(key interface{}, v interface{}) bool { rp.presence.Range(func(key interface{}, v interface{}) bool {
p := v.(types.Presence) p := v.(types.Presence)
if time.Since(p.LastActiveTS.Time()) > cleanupTime { if time.Since(p.LastActiveTS.Time()) > cleanupTime {
rp.updatePresence("unavailable", p.UserID) rp.updatePresence(db, "unavailable", p.UserID)
rp.presence.Delete(key) rp.presence.Delete(key)
} }
return true return true
@ -107,7 +109,7 @@ func (rp *RequestPool) cleanPresence(cleanupTime time.Duration) {
} }
// updatePresence sends presence updates to the SyncAPI and FederationAPI // updatePresence sends presence updates to the SyncAPI and FederationAPI
func (rp *RequestPool) updatePresence(presence string, userID string) { func (rp *RequestPool) updatePresence(db storage.Presence, presence string, userID string) {
if rp.cfg.Matrix.DisablePresence { if rp.cfg.Matrix.DisablePresence {
return return
} }
@ -132,7 +134,13 @@ func (rp *RequestPool) updatePresence(presence string, userID string) {
} }
} }
if err := rp.producer.SendPresence(userID, strings.ToLower(presence)); err != nil { // ensure we also send the current status_msg to federated servers and not nil
dbPresence, err := db.GetPresence(context.Background(), userID)
if err != nil && err != sql.ErrNoRows {
return
}
if err := rp.producer.SendPresence(userID, strings.ToLower(presence), dbPresence.ClientFields.StatusMsg); err != nil {
logrus.WithError(err).Error("Unable to publish presence message from sync") logrus.WithError(err).Error("Unable to publish presence message from sync")
return return
} }
@ -214,7 +222,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer activeSyncRequests.Dec() defer activeSyncRequests.Dec()
rp.updateLastSeen(req, device) rp.updateLastSeen(req, device)
rp.updatePresence(req.FormValue("set_presence"), device.UserID) rp.updatePresence(rp.db, req.FormValue("set_presence"), device.UserID)
waitingSyncRequests.Inc() waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec() defer waitingSyncRequests.Dec()

View file

@ -1,22 +1,43 @@
package sync package sync
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type dummyPublisher struct { type dummyPublisher struct {
count int count int
} }
func (d *dummyPublisher) SendPresence(userID, presence string) error { func (d *dummyPublisher) SendPresence(userID, presence string, statusMsg *string) error {
d.count++ d.count++
return nil return nil
} }
type dummyDB struct{}
func (d dummyDB) UpdatePresence(ctx context.Context, userID, presence string, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
return 0, nil
}
func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presence, error) {
return &types.Presence{}, nil
}
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.Presence, error) {
return map[string]*types.Presence{}, nil
}
func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return 0, nil
}
func TestRequestPool_updatePresence(t *testing.T) { func TestRequestPool_updatePresence(t *testing.T) {
type args struct { type args struct {
presence string presence string
@ -89,11 +110,12 @@ func TestRequestPool_updatePresence(t *testing.T) {
}, },
}, },
} }
go rp.cleanPresence(time.Millisecond * 50) db := dummyDB{}
go rp.cleanPresence(db, time.Millisecond*50)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
beforeCount := publisher.count beforeCount := publisher.count
rp.updatePresence(tt.args.presence, tt.args.userID) rp.updatePresence(db, tt.args.presence, tt.args.userID)
if tt.wantIncrease && publisher.count <= beforeCount { if tt.wantIncrease && publisher.count <= beforeCount {
t.Fatalf("expected count to increase: %d <= %d", publisher.count, beforeCount) t.Fatalf("expected count to increase: %d <= %d", publisher.count, beforeCount)
} }