Fix newly joined users presence ()

Fixes  
Also refactors the presence stream to not hit the database for every
user, instead queries all users at once now.
This commit is contained in:
Till 2022-12-08 08:25:03 +01:00 committed by GitHub
parent 0351618ff4
commit c136a450d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 263 additions and 75 deletions

View file

@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error {
// Normal NATS subscription, used by Request/Reply // Normal NATS subscription, used by Request/Reply
_, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) {
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
presence, err := s.db.GetPresence(context.Background(), userID) presences, err := s.db.GetPresences(context.Background(), []string{userID})
m := &nats.Msg{ m := &nats.Msg{
Header: nats.Header{}, Header: nats.Header{},
} }
@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error {
} }
return return
} }
if presence == nil {
presence = &types.PresenceInternal{ presence := &types.PresenceInternal{
UserID: userID, UserID: userID,
} }
if len(presences) > 0 {
presence = presences[0]
} }
deviceRes := api.QueryDevicesResponse{} deviceRes := api.QueryDevicesResponse{}

View file

@ -106,7 +106,7 @@ type DatabaseTransaction interface {
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
} }
@ -186,7 +186,7 @@ type Database interface {
} }
type Presence interface { type Presence interface {
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error)
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
} }

View file

@ -19,10 +19,12 @@ import (
"database/sql" "database/sql"
"time" "time"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
const presenceSchema = ` const presenceSchema = `
@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id" " RETURNING id"
const selectPresenceForUserSQL = "" + const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" + "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1" " WHERE user_id = ANY($1)"
const selectMaxPresenceSQL = "" + const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence(
return return
} }
// GetPresenceForUser returns the current presence of a user. // GetPresenceForUsers returns the current presence for a list of users.
func (p *presenceStatements) GetPresenceForUser( // If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userIDs []string,
) (*types.PresenceInternal, error) { ) ([]*types.PresenceInternal, error) {
result := &types.PresenceInternal{ result := make([]*types.PresenceInternal, 0, len(userIDs))
UserID: userID,
}
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) rows, err := stmt.QueryContext(ctx, pq.Array(userIDs))
if err == sql.ErrNoRows { if err != nil {
return nil, nil return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
} }
result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t
return pos, err return pos, err
} }
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID) return d.Presence.GetPresenceForUsers(ctx, nil, userIDs)
} }
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {

View file

@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
} }
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, d.txn, userID) return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
} }
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {

View file

@ -17,12 +17,14 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
const presenceSchema = ` const presenceSchema = `
@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id" " RETURNING id"
const selectPresenceForUserSQL = "" + const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" + "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1" " WHERE user_id IN ($1)"
const selectMaxPresenceSQL = "" + const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence(
return return
} }
// GetPresenceForUser returns the current presence of a user. // GetPresenceForUsers returns the current presence for a list of users.
func (p *presenceStatements) GetPresenceForUser( // If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userIDs []string,
) (*types.PresenceInternal, error) { ) ([]*types.PresenceInternal, error) {
result := &types.PresenceInternal{ qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
UserID: userID, prepStmt, err := p.db.Prepare(qry)
if err != nil {
return nil, err
} }
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed")
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows { params := make([]interface{}, len(userIDs))
return nil, nil for i := range userIDs {
params[i] = userIDs[i]
}
rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
result := make([]*types.PresenceInternal, 0, len(userIDs))
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
} }
result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -207,7 +207,7 @@ type Ignores interface {
type Presence interface { type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
} }

View file

@ -0,0 +1,136 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Presence
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresPresenceTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqlitePresenceTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, close
}
func TestPresence(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
ctx := context.Background()
statusMsg := "Hello World!"
timestamp := gomatrixserverlib.AsTimestamp(time.Now())
var txn *sql.Tx
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustPresenceTable(t, dbType)
defer closeDB()
// Insert some presences
pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos := types.StreamPosition(1)
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos = 2
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
// verify the expected max presence ID
maxPos, err := tab.GetMaxPresenceID(ctx, txn)
if err != nil {
t.Error(err)
}
if maxPos != wantPos {
t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos)
}
// This should increment the position
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true)
if err != nil {
t.Error(err)
}
wantPos = pos
if wantPos <= maxPos {
t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos)
}
// This should return only Bobs status
presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10})
if err != nil {
t.Error(err)
}
if c := len(presences); c > 1 {
t.Errorf("expected only one presence, got %d", c)
}
// Validate the response
wantPresence := &types.PresenceInternal{
UserID: bob.ID,
Presence: types.PresenceOnline,
StreamPos: wantPos,
LastActiveTS: timestamp,
ClientFields: types.PresenceClientResponse{
LastActiveAgo: 0,
Presence: types.PresenceOnline.String(),
StatusMsg: &statusMsg,
},
}
if !reflect.DeepEqual(wantPresence, presences[bob.ID]) {
t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence)
}
// Try getting presences for existing and non-existing users
getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"}
presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers)
if err != nil {
t.Error(err)
}
if len(presencesForUsers) >= len(getUsers) {
t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers))
}
})
}

View file

@ -17,6 +17,7 @@ package streams
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync(
return from return from
} }
if len(presences) == 0 { getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences)
if err != nil {
req.Log.WithError(err).Error("getNeededUsersFromRequest failed")
return from
}
// Got no presence between range and no presence to get from the database
if len(getPresenceForUsers) == 0 && len(presences) == 0 {
return to return to
} }
// add newly joined rooms user presences dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers)
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) > 0 {
// TODO: Check if this is working better than before.
if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
req.Log.WithError(err).Error("unable to refresh notifier lists")
return from
}
NewlyJoinedLoop:
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
if err != nil { if err != nil {
req.Log.WithError(err).Error("unable to query presence for user") req.Log.WithError(err).Error("unable to query presence for user")
_ = snapshot.Rollback() _ = snapshot.Rollback()
return from return from
} }
if len(presences) > req.Filter.Presence.Limit { for _, presence := range dbPresences {
break NewlyJoinedLoop presences[presence.UserID] = presence
}
}
}
} }
lastPos := from lastPos := from
@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync(
return lastPos return lastPos
} }
func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) {
getPresenceForUsers := []string{}
// Add presence for users which newly joined a room
for userID := range req.MembershipChanges {
if _, ok := presences[userID]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, userID)
}
// add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) == 0 {
return getPresenceForUsers, nil
}
// TODO: Check if this is working better than before.
if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err)
}
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, roomUsers[i])
}
}
return getPresenceForUsers, nil
}
func joinedRooms(res *types.Response, userID string) []string { func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string var roomIDs []string
for roomID, join := range res.Rooms.Join { for roomID, join := range res.Rooms.Join {

View file

@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
} }
// ensure we also send the current status_msg to federated servers and not nil // ensure we also send the current status_msg to federated servers and not nil
dbPresence, err := db.GetPresence(context.Background(), userID) dbPresence, err := db.GetPresences(context.Background(), []string{userID})
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return return
} }
if dbPresence != nil { if len(dbPresence) > 0 && dbPresence[0] != nil {
newPresence.ClientFields = dbPresence.ClientFields newPresence.ClientFields = dbPresence[0].ClientFields
} }
newPresence.ClientFields.Presence = presenceID.String() newPresence.ClientFields.Presence = presenceID.String()

View file

@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ
return 0, nil return 0, nil
} }
func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) {
return &types.PresenceInternal{}, nil return []*types.PresenceInternal{}, nil
} }
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {