diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 018d61f70..618aad95c 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -38,6 +38,7 @@ type Database interface { GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) + MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 6566544d6..1242a3221 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -62,9 +62,15 @@ const selectMembershipSQL = "" + " ORDER BY stream_pos DESC" + " LIMIT 1" +const selectMembershipCountSQL = "" + + "SELECT COUNT(*) FROM (" + + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + + ") t WHERE t.membership = $3" + type membershipsStatements struct { - upsertMembershipStmt *sql.Stmt - selectMembershipStmt *sql.Stmt + upsertMembershipStmt *sql.Stmt + selectMembershipStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -79,6 +85,9 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { return nil, err } + if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { + return nil, err + } return s, nil } @@ -109,3 +118,11 @@ func (s *membershipsStatements) SelectMembership( err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) return } + +func (s *membershipsStatements) SelectMembershipCount( + ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, +) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count) + return +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 6b78f69d7..7f47a7e48 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -119,6 +119,10 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) } +func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { + return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos) +} + func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) } diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index e5445e815..776bf3da3 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -63,9 +63,15 @@ const selectMembershipSQL = "" + " ORDER BY stream_pos DESC" + " LIMIT 1" +const selectMembershipCountSQL = "" + + "SELECT COUNT(*) FROM (" + + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + + ") t WHERE t.membership = $3" + type membershipsStatements struct { - db *sql.DB - upsertMembershipStmt *sql.Stmt + db *sql.DB + upsertMembershipStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -79,6 +85,9 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { return nil, err } + if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { + return nil, err + } return s, nil } @@ -117,3 +126,11 @@ func (s *membershipsStatements) SelectMembership( err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) return } + +func (s *membershipsStatements) SelectMembershipCount( + ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, +) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count) + return +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 08e589a33..585515328 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -174,6 +174,7 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) + SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ccdac0864..d23209af3 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -253,9 +253,25 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } + hasMembershipChange := false + for _, recentEvent := range recentStreamEvents { + if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { + hasMembershipChange = true + break + } + } + + // Work out how many members are in the room. + joinedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Join, latestPosition) + invitedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Invite, latestPosition) + switch delta.Membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() + if hasMembershipChange { + jr.Summary.JoinedMemberCount = &joinedCount + jr.Summary.InvitedMemberCount = &invitedCount + } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited @@ -367,12 +383,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } + // Work out how many members are in the room. + joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, r.From) + invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, r.From) + // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr = types.NewJoinResponse() + jr.Summary.JoinedMemberCount = &joinedCount + jr.Summary.InvitedMemberCount = &invitedCount jr.Timeline.PrevBatch = prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 01fc303ef..3402efac0 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -386,6 +386,11 @@ func (r *Response) IsEmpty() bool { // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. type JoinResponse struct { + Summary struct { + Heroes []string `json:"m.heroes,omitempty"` + JoinedMemberCount *int `json:"m.joined_member_count,omitempty"` + InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` + } `json:"summary"` State struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"state"` diff --git a/sytest-whitelist b/sytest-whitelist index ab23727ea..7f8f20193 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -661,6 +661,7 @@ Canonical alias can include alt_aliases Can delete canonical alias AS can make room aliases /context/ with lazy_load_members filter works +Room summary counts change when membership changes GET /presence/:user_id/status fetches initial status PUT /presence/:user_id/status updates my presence Presence change reports an event to myself @@ -674,4 +675,4 @@ Presence changes are reported to local room members Presence changes are also reported to remote room members Presence changes to UNAVAILABLE are reported to local room members Presence changes to UNAVAILABLE are reported to remote room members -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115)