Use the roomserver API to query senderID during sync event conversion

This commit is contained in:
Devon Hudson 2023-06-07 09:46:06 -06:00
parent 51132f9e81
commit 1eea02dcfa
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
6 changed files with 54 additions and 58 deletions

View file

@ -495,7 +495,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
} }
// Append the events ve previously retrieved locally. // Append the events ve previously retrieved locally.
events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...)
sort.Sort(eventsByDepth(events)) sort.Sort(eventsByDepth(events))
return return

View file

@ -44,8 +44,8 @@ type DatabaseTransaction interface {
MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
@ -90,7 +90,7 @@ type DatabaseTransaction interface {
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// relevant events within the given ranges for the supplied user ID and device ID. // relevant events within the given ranges for the supplied user ID and device ID.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)

View file

@ -99,7 +99,41 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.He
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
}
func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
} }
// AddInviteEvent stores a new invite event for a user. // AddInviteEvent stores a new invite event for a user.
@ -190,45 +224,6 @@ func (d *Database) UpsertAccountData(
return return
} }
func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
continue
}
if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
}
return out
}
func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (spec.SenderID, error) { // nolint
// TODO: Replace with actual logic for pseudoIDs
return spec.SenderID(userID.String()), nil
}
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.

View file

@ -10,6 +10,7 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types" rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -186,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil
} }
func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
@ -325,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos(
func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
@ -417,7 +418,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
if !peek.Deleted { if !peek.Deleted {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
}) })
} }
@ -462,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI),
RoomID: roomID, RoomID: roomID,
}) })
break break
@ -474,7 +475,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, types.StateDelta{ deltas = append(deltas, types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
NewlyJoined: newlyJoinedRooms[joinedRoomID], NewlyJoined: newlyJoinedRooms[joinedRoomID],
}) })
@ -490,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
stateFilter *synctypes.StateFilter, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI,
) ([]types.StateDelta, []string, error) { ) ([]types.StateDelta, []string, error) {
// Look up all memberships for the user. We only care about rooms that a // Look up all memberships for the user. We only care about rooms that a
// user has ever interacted with — joined to, kicked/banned from, left. // user has ever interacted with — joined to, kicked/banned from, left.
@ -531,7 +532,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[peek.RoomID] = types.StateDelta{ deltas[peek.RoomID] = types.StateDelta{
Membership: spec.Peek, Membership: spec.Peek,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: peek.RoomID, RoomID: peek.RoomID,
} }
} }
@ -560,7 +561,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
deltas[roomID] = types.StateDelta{ deltas[roomID] = types.StateDelta{
Membership: membership, Membership: membership,
MembershipPos: ev.StreamPosition, MembershipPos: ev.StreamPosition,
StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI),
RoomID: roomID, RoomID: roomID,
} }
} }
@ -581,7 +582,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
} }
deltas[joinedRoomID] = types.StateDelta{ deltas[joinedRoomID] = types.StateDelta{
Membership: spec.Join, Membership: spec.Join,
StateEvents: d.StreamEventsToEvents(device, s), StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI),
RoomID: joinedRoomID, RoomID: joinedRoomID,
} }
} }

View file

@ -214,7 +214,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
} }
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
}) })
}) })

View file

@ -175,12 +175,12 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline eventFilter := req.Filter.Room.Timeline
if req.WantFullState { if req.WantFullState {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return from return from
} }
} else { } else {
if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return from return from
} }
@ -275,7 +275,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
limited := dbEvents[delta.RoomID].Limited limited := dbEvents[delta.RoomID].Limited
recEvents := gomatrixserverlib.ReverseTopologicalOrdering( recEvents := gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(device, recentStreamEvents)), gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)),
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
) )
recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents))
@ -512,7 +512,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check. // "Can sync a room with a message with a transaction id" - which does a complete sync to check.
recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)
events := recentEvents events := recentEvents
// Only apply history visibility checks if the response is for joined rooms // Only apply history visibility checks if the response is for joined rooms