Include joined and invite member counts in room summary ()

* Include joined and invite member counts in room summary

This should fix  and also fix the problem where some clients like Element Android, Fluffychat etc would display the wrong member count for a given room.

* Improve SQLite query precision

* Check existence of state key for membership events
This commit is contained in:
Neil Alexander 2022-04-01 16:14:38 +01:00 committed by GitHub
parent 8213b2ba30
commit cd8fac152e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 72 additions and 4 deletions

View file

@ -37,6 +37,7 @@ type Database interface {
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)

View file

@ -62,9 +62,15 @@ const selectMembershipSQL = "" +
" ORDER BY stream_pos DESC" + " ORDER BY stream_pos DESC" +
" LIMIT 1" " LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
") t WHERE t.membership = $3"
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt selectMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
} }
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -79,6 +85,9 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -109,3 +118,11 @@ func (s *membershipsStatements) SelectMembership(
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return return
} }
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
return
}

View file

@ -118,6 +118,10 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos)
}
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
} }

View file

@ -63,9 +63,15 @@ const selectMembershipSQL = "" +
" ORDER BY stream_pos DESC" + " ORDER BY stream_pos DESC" +
" LIMIT 1" " LIMIT 1"
const selectMembershipCountSQL = "" +
"SELECT COUNT(*) FROM (" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
") t WHERE t.membership = $3"
type membershipsStatements struct { type membershipsStatements struct {
db *sql.DB db *sql.DB
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
} }
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@ -79,6 +85,9 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -117,3 +126,11 @@ func (s *membershipsStatements) SelectMembership(
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return return
} }
func (s *membershipsStatements) SelectMembershipCount(
ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition,
) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count)
return
}

View file

@ -174,6 +174,7 @@ type Receipts interface {
type Memberships interface { type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
} }
type NotificationData interface { type NotificationData interface {

View file

@ -253,9 +253,25 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
} }
hasMembershipChange := false
for _, recentEvent := range recentStreamEvents {
if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil {
hasMembershipChange = true
break
}
}
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Invite, latestPosition)
switch delta.Membership { switch delta.Membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if hasMembershipChange {
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
}
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
@ -367,12 +383,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement() prevBatch.Decrement()
} }
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, r.From)
invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, r.From)
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited

View file

@ -377,6 +377,11 @@ func (r *Response) IsEmpty() bool {
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
Summary struct {
Heroes []string `json:"m.heroes,omitempty"`
JoinedMemberCount *int `json:"m.joined_member_count,omitempty"`
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
} `json:"summary"`
State struct { State struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"state"` } `json:"state"`

View file

@ -661,3 +661,4 @@ Canonical alias can include alt_aliases
Can delete canonical alias Can delete canonical alias
AS can make room aliases AS can make room aliases
/context/ with lazy_load_members filter works /context/ with lazy_load_members filter works
Room summary counts change when membership changes