Handle newly joined rooms differently

This commit is contained in:
Till Faelligen 2022-08-24 18:36:47 +02:00
parent deabd9b9b2
commit 80ad177e18
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
4 changed files with 30 additions and 20 deletions

View file

@ -39,7 +39,7 @@ type Database interface {
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, map[string]struct{}, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)

View file

@ -686,7 +686,7 @@ func (d *Database) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) (deltas []types.StateDelta, joinedRoomsIDs []string, newlyJoinedRooms map[string]struct{}, err error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
// - For each room which has membership list changes: // - For each room which has membership list changes:
@ -697,7 +697,7 @@ func (d *Database) GetStateDeltas(
// - Get all CURRENTLY joined rooms, and add them to 'joined' block. // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
txn, err := d.readOnlySnapshot(ctx) txn, err := d.readOnlySnapshot(ctx)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
} }
var succeeded bool var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
@ -707,9 +707,9 @@ func (d *Database) GetStateDeltas(
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil, nil return nil, nil, nil
} }
return nil, nil, nil, err return nil, nil, err
} }
allRoomIDs := make([]string, 0, len(memberships)) allRoomIDs := make([]string, 0, len(memberships))
@ -725,23 +725,23 @@ func (d *Database) GetStateDeltas(
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil, nil return nil, nil, nil
} }
return nil, nil, nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil, nil return nil, nil, nil
} }
return nil, nil, nil, err return nil, nil, err
} }
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, nil, nil, err return nil, nil, err
} }
// add peek blocks // add peek blocks
@ -754,7 +754,7 @@ func (d *Database) GetStateDeltas(
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
} }
return nil, nil, nil, err return nil, nil, err
} }
state[peek.RoomID] = s state[peek.RoomID] = s
} }
@ -768,12 +768,11 @@ func (d *Database) GetStateDeltas(
} }
// handle newly joined rooms and non-joined rooms // handle newly joined rooms and non-joined rooms
newlyJoinedRoomIDs := make(map[string]struct{}, len(memberships)) newlyJoinedRooms := make(map[string]struct{}, len(state))
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join && prevMembership != membership { if membership == gomatrixserverlib.Join && prevMembership != membership {
newlyJoinedRoomIDs[roomID] = struct{}{}
// send full room state down instead of a delta // send full room state down instead of a delta
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
@ -781,9 +780,10 @@ func (d *Database) GetStateDeltas(
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
} }
return nil, nil, nil, err return nil, nil, err
} }
state[roomID] = s state[roomID] = s
newlyJoinedRooms[roomID] = struct{}{}
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
} }
@ -800,15 +800,19 @@ func (d *Database) GetStateDeltas(
// Add in currently joined rooms // Add in currently joined rooms
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ delta := types.StateDelta{
Membership: gomatrixserverlib.Join, Membership: gomatrixserverlib.Join,
StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
RoomID: joinedRoomID, RoomID: joinedRoomID,
}) }
if _, ok := newlyJoinedRooms[joinedRoomID]; ok {
delta.NewlyJoined = true
}
deltas = append(deltas, delta)
} }
succeeded = true succeeded = true
return deltas, joinedRoomIDs, newlyJoinedRoomIDs, nil return deltas, joinedRoomIDs, nil
} }
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync

View file

@ -179,7 +179,6 @@ func (p *PDUStreamProvider) IncrementalSync(
var err error var err error
var stateDeltas []types.StateDelta var stateDeltas []types.StateDelta
var syncJoinedRooms []string var syncJoinedRooms []string
var newlyJoinedRooms map[string]struct{}
stateFilter := req.Filter.Room.State stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
@ -190,7 +189,7 @@ func (p *PDUStreamProvider) IncrementalSync(
return return
} }
} else { } else {
if stateDeltas, syncJoinedRooms, newlyJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return return
} }
@ -213,7 +212,7 @@ func (p *PDUStreamProvider) IncrementalSync(
newRange := r newRange := r
// If this room was joined in this sync, try to fetch // If this room was joined in this sync, try to fetch
// as much timeline events as allowed by the filter. // as much timeline events as allowed by the filter.
if _, ok := newlyJoinedRooms[delta.RoomID]; ok { if delta.NewlyJoined {
// Reverse the range, so we get the most recent first. // Reverse the range, so we get the most recent first.
// This will be limited by the eventFilter. // This will be limited by the eventFilter.
newRange = types.Range{ newRange = types.Range{
@ -227,6 +226,10 @@ func (p *PDUStreamProvider) IncrementalSync(
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to return to
} }
// Reset the position, as it is only for the special case of newly joined rooms
if delta.NewlyJoined {
pos = newRange.From
}
switch { switch {
case r.Backwards && pos < newPos: case r.Backwards && pos < newPos:
fallthrough fallthrough
@ -400,6 +403,8 @@ func applyHistoryVisibilityFilter(
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime), "duration": time.Since(startTime),
"room_id": roomID, "room_id": roomID,
"before": len(recentEvents),
"after": len(events),
}).Debug("applied history visibility (sync)") }).Debug("applied history visibility (sync)")
return events, nil return events, nil
} }

View file

@ -37,6 +37,7 @@ var (
type StateDelta struct { type StateDelta struct {
RoomID string RoomID string
StateEvents []*gomatrixserverlib.HeaderedEvent StateEvents []*gomatrixserverlib.HeaderedEvent
NewlyJoined bool
Membership string Membership string
// The PDU stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.