bugfix: fix race condition when updating presence via /sync (#2470)
* bugfix: fix race condition when updating presence via /sync Previously when presence is updated via /sync, we would send the presence update asyncly via NATS. This created a race condition: - If the presence update is processed quickly, the /sync which triggered the presence update would see an online presence. - If the presence update was processed slowly, the /sync which triggered the presence update would see an offline presence. This is the root cause behind the flakey sytest: 'User sees their own presence in a sync'. The fix is to ensure we update the database/advance the stream position synchronously for local users. * Bugfix for test
This commit is contained in:
parent
ac92e04772
commit
b3162755a9
|
@ -138,9 +138,12 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||||
presence := msg.Header.Get("presence")
|
presence := msg.Header.Get("presence")
|
||||||
timestamp := msg.Header.Get("last_active_ts")
|
timestamp := msg.Header.Get("last_active_ts")
|
||||||
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
|
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
|
||||||
|
|
||||||
logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
|
logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
|
||||||
|
|
||||||
|
if fromSync { // do not process local presence changes; we already did this synchronously.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
ts, err := strconv.Atoi(timestamp)
|
ts, err := strconv.Atoi(timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
|
@ -151,15 +154,19 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||||
newMsg := msg.Header.Get("status_msg")
|
newMsg := msg.Header.Get("status_msg")
|
||||||
statusMsg = &newMsg
|
statusMsg = &newMsg
|
||||||
}
|
}
|
||||||
// OK is already checked, so no need to do it again
|
// already checked, so no need to check error
|
||||||
p, _ := types.PresenceFromString(presence)
|
p, _ := types.PresenceFromString(presence)
|
||||||
pos, err := s.db.UpdatePresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
s.stream.Advance(pos)
|
|
||||||
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
|
|
||||||
|
|
||||||
|
s.EmitPresence(ctx, userID, p, statusMsg, ts, fromSync)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
|
||||||
|
pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stream.Advance(pos)
|
||||||
|
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
|
||||||
|
}
|
||||||
|
|
|
@ -53,19 +53,24 @@ type RequestPool struct {
|
||||||
streams *streams.Streams
|
streams *streams.Streams
|
||||||
Notifier *notifier.Notifier
|
Notifier *notifier.Notifier
|
||||||
producer PresencePublisher
|
producer PresencePublisher
|
||||||
|
consumer PresenceConsumer
|
||||||
}
|
}
|
||||||
|
|
||||||
type PresencePublisher interface {
|
type PresencePublisher interface {
|
||||||
SendPresence(userID string, presence types.Presence, statusMsg *string) error
|
SendPresence(userID string, presence types.Presence, statusMsg *string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PresenceConsumer interface {
|
||||||
|
EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool)
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequestPool makes a new RequestPool
|
// NewRequestPool makes a new RequestPool
|
||||||
func NewRequestPool(
|
func NewRequestPool(
|
||||||
db storage.Database, cfg *config.SyncAPI,
|
db storage.Database, cfg *config.SyncAPI,
|
||||||
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
|
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
|
||||||
rsAPI roomserverAPI.SyncRoomserverAPI,
|
rsAPI roomserverAPI.SyncRoomserverAPI,
|
||||||
streams *streams.Streams, notifier *notifier.Notifier,
|
streams *streams.Streams, notifier *notifier.Notifier,
|
||||||
producer PresencePublisher, enableMetrics bool,
|
producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
|
||||||
) *RequestPool {
|
) *RequestPool {
|
||||||
if enableMetrics {
|
if enableMetrics {
|
||||||
prometheus.MustRegister(
|
prometheus.MustRegister(
|
||||||
|
@ -83,6 +88,7 @@ func NewRequestPool(
|
||||||
streams: streams,
|
streams: streams,
|
||||||
Notifier: notifier,
|
Notifier: notifier,
|
||||||
producer: producer,
|
producer: producer,
|
||||||
|
consumer: consumer,
|
||||||
}
|
}
|
||||||
go rp.cleanLastSeen()
|
go rp.cleanLastSeen()
|
||||||
go rp.cleanPresence(db, time.Minute*5)
|
go rp.cleanPresence(db, time.Minute*5)
|
||||||
|
@ -160,6 +166,13 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
||||||
logrus.WithError(err).Error("Unable to publish presence message from sync")
|
logrus.WithError(err).Error("Unable to publish presence message from sync")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// now synchronously update our view of the world. It's critical we do this before calculating
|
||||||
|
// the /sync response else we may not return presence: online immediately.
|
||||||
|
rp.consumer.EmitPresence(
|
||||||
|
context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg,
|
||||||
|
int(gomatrixserverlib.AsTimestamp(time.Now())), true,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) {
|
func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) {
|
||||||
|
|
|
@ -38,6 +38,12 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dummyConsumer struct{}
|
||||||
|
|
||||||
|
func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestRequestPool_updatePresence(t *testing.T) {
|
func TestRequestPool_updatePresence(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
presence string
|
presence string
|
||||||
|
@ -45,6 +51,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
}
|
}
|
||||||
publisher := &dummyPublisher{}
|
publisher := &dummyPublisher{}
|
||||||
|
consumer := &dummyConsumer{}
|
||||||
syncMap := sync.Map{}
|
syncMap := sync.Map{}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -101,6 +108,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
|
||||||
rp := &RequestPool{
|
rp := &RequestPool{
|
||||||
presence: &syncMap,
|
presence: &syncMap,
|
||||||
producer: publisher,
|
producer: publisher,
|
||||||
|
consumer: consumer,
|
||||||
cfg: &config.SyncAPI{
|
cfg: &config.SyncAPI{
|
||||||
Matrix: &config.Global{
|
Matrix: &config.Global{
|
||||||
JetStream: config.JetStream{
|
JetStream: config.JetStream{
|
||||||
|
|
|
@ -64,8 +64,17 @@ func AddPublicRoutes(
|
||||||
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
||||||
JetStream: js,
|
JetStream: js,
|
||||||
}
|
}
|
||||||
|
presenceConsumer := consumers.NewPresenceConsumer(
|
||||||
|
base.ProcessContext, cfg, js, natsClient, syncDB,
|
||||||
|
notifier, streams.PresenceStreamProvider,
|
||||||
|
userAPI,
|
||||||
|
)
|
||||||
|
|
||||||
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, base.EnableMetrics)
|
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
|
||||||
|
|
||||||
|
if err = presenceConsumer.Start(); err != nil {
|
||||||
|
logrus.WithError(err).Panicf("failed to start presence consumer")
|
||||||
|
}
|
||||||
|
|
||||||
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
|
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
|
||||||
JetStream: js,
|
JetStream: js,
|
||||||
|
@ -131,15 +140,6 @@ func AddPublicRoutes(
|
||||||
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
||||||
}
|
}
|
||||||
|
|
||||||
presenceConsumer := consumers.NewPresenceConsumer(
|
|
||||||
base.ProcessContext, cfg, js, natsClient, syncDB,
|
|
||||||
notifier, streams.PresenceStreamProvider,
|
|
||||||
userAPI,
|
|
||||||
)
|
|
||||||
if err = presenceConsumer.Start(); err != nil {
|
|
||||||
logrus.WithError(err).Panicf("failed to start presence consumer")
|
|
||||||
}
|
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
base.PublicClientAPIMux, requestPool, syncDB, userAPI,
|
base.PublicClientAPIMux, requestPool, syncDB, userAPI,
|
||||||
rsAPI, cfg, base.Caches,
|
rsAPI, cfg, base.Caches,
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type syncRoomserverAPI struct {
|
type syncRoomserverAPI struct {
|
||||||
|
@ -256,6 +257,60 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that if we hit /sync we get back presence: online, regardless of whether messages get delivered
|
||||||
|
// via NATS. Regression test for a flakey test "User sees their own presence in a sync"
|
||||||
|
func TestSyncAPIUpdatePresenceImmediately(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
testSyncAPIUpdatePresenceImmediately(t, dbType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
|
||||||
|
user := test.NewUser(t)
|
||||||
|
alice := userapi.Device{
|
||||||
|
ID: "ALICEID",
|
||||||
|
UserID: user.ID,
|
||||||
|
AccessToken: "ALICE_BEARER_TOKEN",
|
||||||
|
DisplayName: "Alice",
|
||||||
|
AccountType: userapi.AccountTypeUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
base.Cfg.Global.Presence.EnableOutbound = true
|
||||||
|
base.Cfg.Global.Presence.EnableInbound = true
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
|
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||||
|
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||||
|
"access_token": alice.AccessToken,
|
||||||
|
"timeout": "0",
|
||||||
|
"set_presence": "online",
|
||||||
|
})))
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Fatalf("got HTTP %d want %d", w.Code, 200)
|
||||||
|
}
|
||||||
|
var res types.Response
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
|
||||||
|
t.Errorf("failed to decode response body: %s", err)
|
||||||
|
}
|
||||||
|
if len(res.Presence.Events) != 1 {
|
||||||
|
t.Fatalf("expected 1 presence events, got: %+v", res.Presence.Events)
|
||||||
|
}
|
||||||
|
if res.Presence.Events[0].Sender != alice.UserID {
|
||||||
|
t.Errorf("sender: got %v want %v", res.Presence.Events[0].Sender, alice.UserID)
|
||||||
|
}
|
||||||
|
if res.Presence.Events[0].Type != "m.presence" {
|
||||||
|
t.Errorf("type: got %v want %v", res.Presence.Events[0].Type, "m.presence")
|
||||||
|
}
|
||||||
|
if gjson.ParseBytes(res.Presence.Events[0].Content).Get("presence").Str != "online" {
|
||||||
|
t.Errorf("content: not online, got %v", res.Presence.Events[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
|
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
|
||||||
result := make([]*nats.Msg, len(input))
|
result := make([]*nats.Msg, len(input))
|
||||||
for i, ev := range input {
|
for i, ev := range input {
|
||||||
|
|
Loading…
Reference in a new issue