Handle newly joined rooms

This commit is contained in:
Till Faelligen 2022-08-24 16:31:21 +02:00
parent af6ca1ab26
commit 6db4d96718
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
3 changed files with 35 additions and 21 deletions

View file

@ -19,10 +19,11 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database interface { type Database interface {
@ -38,7 +39,7 @@ type Database interface {
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, map[string]struct{}, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)

View file

@ -686,7 +686,7 @@ func (d *Database) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StateDelta, []string, error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, newlyJoinedRooms map[string]struct{}, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
// - For each room which has membership list changes: // - For each room which has membership list changes:
@ -697,7 +697,7 @@ func (d *Database) GetStateDeltas(
// - Get all CURRENTLY joined rooms, and add them to 'joined' block. // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
txn, err := d.readOnlySnapshot(ctx) txn, err := d.readOnlySnapshot(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) return nil, nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
} }
var succeeded bool var succeeded bool
defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
@ -707,9 +707,9 @@ func (d *Database) GetStateDeltas(
memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil, nil
} }
return nil, nil, err return nil, nil, nil, err
} }
allRoomIDs := make([]string, 0, len(memberships)) allRoomIDs := make([]string, 0, len(memberships))
@ -721,29 +721,27 @@ func (d *Database) GetStateDeltas(
} }
} }
var deltas []types.StateDelta
// get all the state events ever (i.e. for all available rooms) between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil, nil
} }
return nil, nil, err return nil, nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, nil return nil, nil, nil, nil
} }
return nil, nil, err return nil, nil, nil, err
} }
// find out which rooms this user is peeking, if any. // find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten // We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, nil, err return nil, nil, nil, err
} }
// add peek blocks // add peek blocks
@ -756,7 +754,7 @@ func (d *Database) GetStateDeltas(
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
} }
return nil, nil, err return nil, nil, nil, err
} }
state[peek.RoomID] = s state[peek.RoomID] = s
} }
@ -770,10 +768,12 @@ func (d *Database) GetStateDeltas(
} }
// handle newly joined rooms and non-joined rooms // handle newly joined rooms and non-joined rooms
newlyJoinedRoomIDs := make(map[string]struct{}, len(memberships))
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join && prevMembership != membership { if membership == gomatrixserverlib.Join && prevMembership != membership {
newlyJoinedRoomIDs[roomID] = struct{}{}
// send full room state down instead of a delta // send full room state down instead of a delta
var s []types.StreamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
@ -781,7 +781,7 @@ func (d *Database) GetStateDeltas(
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue continue
} }
return nil, nil, err return nil, nil, nil, err
} }
state[roomID] = s state[roomID] = s
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
@ -808,7 +808,7 @@ func (d *Database) GetStateDeltas(
} }
succeeded = true succeeded = true
return deltas, joinedRoomIDs, nil return deltas, joinedRoomIDs, newlyJoinedRoomIDs, nil
} }
// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync

View file

@ -178,24 +178,25 @@ func (p *PDUStreamProvider) IncrementalSync(
var err error var err error
var stateDeltas []types.StateDelta var stateDeltas []types.StateDelta
var joinedRooms []string var syncJoinedRooms []string
var newlyJoinedRooms map[string]struct{}
stateFilter := req.Filter.Room.State stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return return
} }
} else { } else {
if stateDeltas, joinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, newlyJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return return
} }
} }
for _, roomID := range joinedRooms { for _, roomID := range syncJoinedRooms {
req.Rooms[roomID] = gomatrixserverlib.Join req.Rooms[roomID] = gomatrixserverlib.Join
} }
@ -209,8 +210,20 @@ func (p *PDUStreamProvider) IncrementalSync(
newPos = from newPos = from
for _, delta := range stateDeltas { for _, delta := range stateDeltas {
newRange := r
// If this room was joined in this sync, try to fetch
// as much timeline events as allowed by the filter.
if _, ok := newlyJoinedRooms[delta.RoomID]; ok {
// Reverse the range, so we get the most recent first.
// This will be limited by the eventFilter.
newRange = types.Range{
From: r.To,
To: 0,
Backwards: true,
}
}
var pos types.StreamPosition var pos types.StreamPosition
if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil { if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to return to
} }