diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 145059c2d..6e3150c29 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error { // Normal NATS subscription, used by Request/Reply _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { userID := msg.Header.Get(jetstream.UserID) - presence, err := s.db.GetPresence(context.Background(), userID) + presences, err := s.db.GetPresences(context.Background(), []string{userID}) m := &nats.Msg{ Header: nats.Header{}, } @@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error { } return } - if presence == nil { - presence = &types.PresenceInternal{ - UserID: userID, - } + + presence := &types.PresenceInternal{ + UserID: userID, + } + if len(presences) > 0 { + presence = presences[0] } deviceRes := api.QueryDevicesResponse{} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 97c2ced49..75afbce15 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -106,7 +106,7 @@ type DatabaseTransaction interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) } @@ -186,7 +186,7 @@ type Database interface { } type Presence interface { - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 7194afea6..a3f7c5213 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id = ANY($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, - } + userIDs []string, +) ([]*types.PresenceInternal, error) { + result := make([]*types.PresenceInternal, 0, len(userIDs)) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + rows, err := stmt.QueryContext(ctx, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index f2064fb89..df2338cf8 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t return pos, err } -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, nil, userIDs) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c3763521c..77afa0290 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } -func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index b61a825df..7641de92f 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,12 +17,14 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id IN ($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, + userIDs []string, +) ([]*types.PresenceInternal, error) { + qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + prepStmt, err := p.db.Prepare(qry) + if err != nil { + return nil, err } - stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed") + + params := make([]interface{}, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + result := make([]*types.PresenceInternal, 0, len(userIDs)) + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2c4f04ec2..a0574b257 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -207,7 +207,7 @@ type Ignores interface { type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) - GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) + GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go new file mode 100644 index 000000000..dce0c695a --- /dev/null +++ b/syncapi/storage/tables/presence_table_test.go @@ -0,0 +1,136 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Presence + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresPresenceTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqlitePresenceTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, close +} + +func TestPresence(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + ctx := context.Background() + + statusMsg := "Hello World!" + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + + var txn *sql.Tx + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustPresenceTable(t, dbType) + defer closeDB() + + // Insert some presences + pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos := types.StreamPosition(1) + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos = 2 + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + + // verify the expected max presence ID + maxPos, err := tab.GetMaxPresenceID(ctx, txn) + if err != nil { + t.Error(err) + } + if maxPos != wantPos { + t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos) + } + + // This should increment the position + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true) + if err != nil { + t.Error(err) + } + wantPos = pos + if wantPos <= maxPos { + t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos) + } + + // This should return only Bobs status + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + if err != nil { + t.Error(err) + } + + if c := len(presences); c > 1 { + t.Errorf("expected only one presence, got %d", c) + } + + // Validate the response + wantPresence := &types.PresenceInternal{ + UserID: bob.ID, + Presence: types.PresenceOnline, + StreamPos: wantPos, + LastActiveTS: timestamp, + ClientFields: types.PresenceClientResponse{ + LastActiveAgo: 0, + Presence: types.PresenceOnline.String(), + StatusMsg: &statusMsg, + }, + } + if !reflect.DeepEqual(wantPresence, presences[bob.ID]) { + t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence) + } + + // Try getting presences for existing and non-existing users + getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"} + presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers) + if err != nil { + t.Error(err) + } + + if len(presencesForUsers) >= len(getUsers) { + t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers)) + } + }) + +} diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 030b7c5d5..445e46b3a 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,6 +17,7 @@ package streams import ( "context" "encoding/json" + "fmt" "sync" "github.com/matrix-org/gomatrixserverlib" @@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - if len(presences) == 0 { + getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences) + if err != nil { + req.Log.WithError(err).Error("getNeededUsersFromRequest failed") + return from + } + + // Got no presence between range and no presence to get from the database + if len(getPresenceForUsers) == 0 && len(presences) == 0 { return to } - // add newly joined rooms user presences - newlyJoined := joinedRooms(req.Response, req.Device.UserID) - if len(newlyJoined) > 0 { - // TODO: Check if this is working better than before. - if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { - req.Log.WithError(err).Error("unable to refresh notifier lists") - return from - } - NewlyJoinedLoop: - for _, roomID := range newlyJoined { - roomUsers := p.notifier.JoinedUsers(roomID) - for i := range roomUsers { - // we already got a presence from this user - if _, ok := presences[roomUsers[i]]; ok { - continue - } - // Bear in mind that this might return nil, but at least populating - // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) - if err != nil { - req.Log.WithError(err).Error("unable to query presence for user") - _ = snapshot.Rollback() - return from - } - if len(presences) > req.Filter.Presence.Limit { - break NewlyJoinedLoop - } - } - } + dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers) + if err != nil { + req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() + return from + } + for _, presence := range dbPresences { + presences[presence.UserID] = presence } lastPos := from @@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync( return lastPos } +func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) { + getPresenceForUsers := []string{} + // Add presence for users which newly joined a room + for userID := range req.MembershipChanges { + if _, ok := presences[userID]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, userID) + } + + // add newly joined rooms user presences + newlyJoined := joinedRooms(req.Response, req.Device.UserID) + if len(newlyJoined) == 0 { + return getPresenceForUsers, nil + } + + // TODO: Check if this is working better than before. + if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { + return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err) + } + for _, roomID := range newlyJoined { + roomUsers := p.notifier.JoinedUsers(roomID) + for i := range roomUsers { + // we already got a presence from this user + if _, ok := presences[roomUsers[i]]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, roomUsers[i]) + } + } + return getPresenceForUsers, nil +} + func joinedRooms(res *types.Response, userID string) []string { var roomIDs []string for roomID, join := range res.Rooms.Join { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 29d92b293..b086567b8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) + dbPresence, err := db.GetPresences(context.Background(), []string{userID}) if err != nil && err != sql.ErrNoRows { return } - if dbPresence != nil { - newPresence.ClientFields = dbPresence.ClientFields + if len(dbPresence) > 0 && dbPresence[0] != nil { + newPresence.ClientFields = dbPresence[0].ClientFields } newPresence.ClientFields.Presence = presenceID.String() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 3e5769d8c..faa0b49c6 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ return 0, nil } -func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return &types.PresenceInternal{}, nil +func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) { + return []*types.PresenceInternal{}, nil } func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {