Fix newly joined users presence (#2854)

Fixes #2803 
Also refactors the presence stream to not hit the database for every
user, instead queries all users at once now.
This commit is contained in:
Till 2022-12-08 08:25:03 +01:00 committed by GitHub
parent 0351618ff4
commit c136a450d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 263 additions and 75 deletions

View file

@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error {
// Normal NATS subscription, used by Request/Reply // Normal NATS subscription, used by Request/Reply
_, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) {
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
presence, err := s.db.GetPresence(context.Background(), userID) presences, err := s.db.GetPresences(context.Background(), []string{userID})
m := &nats.Msg{ m := &nats.Msg{
Header: nats.Header{}, Header: nats.Header{},
} }
@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error {
} }
return return
} }
if presence == nil {
presence = &types.PresenceInternal{ presence := &types.PresenceInternal{
UserID: userID, UserID: userID,
} }
if len(presences) > 0 {
presence = presences[0]
} }
deviceRes := api.QueryDevicesResponse{} deviceRes := api.QueryDevicesResponse{}

View file

@ -106,7 +106,7 @@ type DatabaseTransaction interface {
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
} }
@ -186,7 +186,7 @@ type Database interface {
} }
type Presence interface { type Presence interface {
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error)
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
} }

View file

@ -19,10 +19,12 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
const presenceSchema = ` const presenceSchema = `
@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id" " RETURNING id"
const selectPresenceForUserSQL = "" + const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" + "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1" " WHERE user_id = ANY($1)"
const selectMaxPresenceSQL = "" + const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence(
return return
} }
// GetPresenceForUser returns the current presence of a user. // GetPresenceForUsers returns the current presence for a list of users.
func (p *presenceStatements) GetPresenceForUser( // If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userIDs []string,
) (*types.PresenceInternal, error) { ) ([]*types.PresenceInternal, error) {
result := &types.PresenceInternal{ result := make([]*types.PresenceInternal, 0, len(userIDs))
UserID: userID,
}
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) rows, err := stmt.QueryContext(ctx, pq.Array(userIDs))
if err == sql.ErrNoRows { if err != nil {
return nil, nil return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
} }
result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t
return pos, err return pos, err
} }
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID) return d.Presence.GetPresenceForUsers(ctx, nil, userIDs)
} }
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {

View file

@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
} }
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, d.txn, userID) return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
} }
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {

View file

@ -17,12 +17,14 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
const presenceSchema = ` const presenceSchema = `
@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id" " RETURNING id"
const selectPresenceForUserSQL = "" + const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" + "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1" " WHERE user_id IN ($1)"
const selectMaxPresenceSQL = "" + const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence(
return return
} }
// GetPresenceForUser returns the current presence of a user. // GetPresenceForUsers returns the current presence for a list of users.
func (p *presenceStatements) GetPresenceForUser( // If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userIDs []string,
) (*types.PresenceInternal, error) { ) ([]*types.PresenceInternal, error) {
result := &types.PresenceInternal{ qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
UserID: userID, prepStmt, err := p.db.Prepare(qry)
if err != nil {
return nil, err
} }
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed")
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows { params := make([]interface{}, len(userIDs))
return nil, nil for i := range userIDs {
params[i] = userIDs[i]
}
rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
result := make([]*types.PresenceInternal, 0, len(userIDs))
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
} }
result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -207,7 +207,7 @@ type Ignores interface {
type Presence interface { type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
} }

View file

@ -0,0 +1,136 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Presence
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresPresenceTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqlitePresenceTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, close
}
func TestPresence(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
ctx := context.Background()
statusMsg := "Hello World!"
timestamp := gomatrixserverlib.AsTimestamp(time.Now())
var txn *sql.Tx
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustPresenceTable(t, dbType)
defer closeDB()
// Insert some presences
pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos := types.StreamPosition(1)
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos = 2
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
// verify the expected max presence ID
maxPos, err := tab.GetMaxPresenceID(ctx, txn)
if err != nil {
t.Error(err)
}
if maxPos != wantPos {
t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos)
}
// This should increment the position
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true)
if err != nil {
t.Error(err)
}
wantPos = pos
if wantPos <= maxPos {
t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos)
}
// This should return only Bobs status
presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10})
if err != nil {
t.Error(err)
}
if c := len(presences); c > 1 {
t.Errorf("expected only one presence, got %d", c)
}
// Validate the response
wantPresence := &types.PresenceInternal{
UserID: bob.ID,
Presence: types.PresenceOnline,
StreamPos: wantPos,
LastActiveTS: timestamp,
ClientFields: types.PresenceClientResponse{
LastActiveAgo: 0,
Presence: types.PresenceOnline.String(),
StatusMsg: &statusMsg,
},
}
if !reflect.DeepEqual(wantPresence, presences[bob.ID]) {
t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence)
}
// Try getting presences for existing and non-existing users
getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"}
presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers)
if err != nil {
t.Error(err)
}
if len(presencesForUsers) >= len(getUsers) {
t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers))
}
})
}

View file

@ -17,6 +17,7 @@ package streams
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync(
return from return from
} }
if len(presences) == 0 { getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences)
if err != nil {
req.Log.WithError(err).Error("getNeededUsersFromRequest failed")
return from
}
// Got no presence between range and no presence to get from the database
if len(getPresenceForUsers) == 0 && len(presences) == 0 {
return to return to
} }
// add newly joined rooms user presences dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers)
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) > 0 {
// TODO: Check if this is working better than before.
if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
req.Log.WithError(err).Error("unable to refresh notifier lists")
return from
}
NewlyJoinedLoop:
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
if err != nil { if err != nil {
req.Log.WithError(err).Error("unable to query presence for user") req.Log.WithError(err).Error("unable to query presence for user")
_ = snapshot.Rollback() _ = snapshot.Rollback()
return from return from
} }
if len(presences) > req.Filter.Presence.Limit { for _, presence := range dbPresences {
break NewlyJoinedLoop presences[presence.UserID] = presence
}
}
}
} }
lastPos := from lastPos := from
@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync(
return lastPos return lastPos
} }
func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) {
getPresenceForUsers := []string{}
// Add presence for users which newly joined a room
for userID := range req.MembershipChanges {
if _, ok := presences[userID]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, userID)
}
// add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) == 0 {
return getPresenceForUsers, nil
}
// TODO: Check if this is working better than before.
if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err)
}
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, roomUsers[i])
}
}
return getPresenceForUsers, nil
}
func joinedRooms(res *types.Response, userID string) []string { func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string var roomIDs []string
for roomID, join := range res.Rooms.Join { for roomID, join := range res.Rooms.Join {

View file

@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
} }
// ensure we also send the current status_msg to federated servers and not nil // ensure we also send the current status_msg to federated servers and not nil
dbPresence, err := db.GetPresence(context.Background(), userID) dbPresence, err := db.GetPresences(context.Background(), []string{userID})
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return return
} }
if dbPresence != nil { if len(dbPresence) > 0 && dbPresence[0] != nil {
newPresence.ClientFields = dbPresence.ClientFields newPresence.ClientFields = dbPresence[0].ClientFields
} }
newPresence.ClientFields.Presence = presenceID.String() newPresence.ClientFields.Presence = presenceID.String()

View file

@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ
return 0, nil return 0, nil
} }
func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) {
return &types.PresenceInternal{}, nil return []*types.PresenceInternal{}, nil
} }
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {