Include joined and invite member counts in room summary (#2315)

* Include joined and invite member counts in room summary

This should fix #2314 and also fix the problem where some clients like Element Android, Fluffychat etc would display the wrong member count for a given room.

* Improve SQLite query precision

* Check existence of state key for membership events
This commit is contained in:
Neil Alexander 2022-04-01 16:14:38 +01:00 committed by GitHub
parent 8213b2ba30
commit cd8fac152e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 72 additions and 4 deletions

View file

@ -37,6 +37,7 @@ type Database interface {
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)

View file

@ -62,9 +62,15 @@ const selectMembershipSQL = "" +
" ORDER BY stream_pos DESC" + " ORDER BY stream_pos DESC" +
" LIMIT 1" " LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
") t WHERE t.membership = $3"
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
} }
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -79,6 +85,9 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -109,3 +118,11 @@ func (s *membershipsStatements) SelectMembership(
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return return
} }
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
return
}

View file

@ -118,6 +118,10 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos)
}
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }

View file

@ -63,9 +63,15 @@ const selectMembershipSQL = "" +
" ORDER BY stream_pos DESC" + " ORDER BY stream_pos DESC" +
" LIMIT 1" " LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
") t WHERE t.membership = $3"
type membershipsStatements struct { type membershipsStatements struct {
db *sql.DB db *sql.DB
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
} }
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -79,6 +85,9 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -117,3 +126,11 @@ func (s *membershipsStatements) SelectMembership(
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return return
} }
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
return
}

View file

@ -174,6 +174,7 @@ type Receipts interface {
type Memberships interface { type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
} }
type NotificationData interface { type NotificationData interface {

View file

@ -253,9 +253,25 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
} }
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
hasMembershipChange = true
break
}
}
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Invite, latestPosition)
switch delta.Membership { switch delta.Membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if hasMembershipChange {
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
}
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
@ -367,12 +383,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement() prevBatch.Decrement()
} }
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, r.From)
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, r.From)
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited

View file

@ -377,6 +377,11 @@ func (r *Response) IsEmpty() bool {
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
Summary struct {
Heroes []string `json:"m.heroes,omitempty"`
JoinedMemberCount *int `json:"m.joined_member_count,omitempty"`
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
} `json:"summary"`
State struct { State struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"state"` } `json:"state"`

View file

@ -661,3 +661,4 @@ Canonical alias can include alt_aliases
Can delete canonical alias Can delete canonical alias
AS can make room aliases AS can make room aliases
/context/ with lazy_load_members filter works /context/ with lazy_load_members filter works
Room summary counts change when membership changes