diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index 8a5b9648d..84417a348 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -19,6 +19,9 @@ import ( "database/sql" "fmt" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" // Import the postgres database driver. _ "github.com/lib/pq" @@ -86,13 +89,17 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database +// Returns an error if there was a problem talking with the database. +// Does not include any transaction IDs in the returned events. func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) if err != nil { return nil, err } - return streamEventsToEvents(streamEvents), nil + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return streamEventsToEvents(nil, streamEvents), nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -208,10 +215,14 @@ func (d *SyncServerDatabase) syncStreamPositionTx( return types.StreamPosition(maxID), nil } -// IncrementalSync returns all the data needed in order to create an incremental sync response. +// IncrementalSync returns all the data needed in order to create an incremental +// sync response for the given user. Events returned will include any client +// transaction IDs associated with the given device. These transaction IDs come +// from when the device sent the event via an API that included a transaction +// ID. func (d *SyncServerDatabase) IncrementalSync( ctx context.Context, - userID string, + device authtypes.Device, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int, ) (*types.Response, error) { @@ -226,21 +237,21 @@ func (d *SyncServerDatabase) IncrementalSync( // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // This works out what the 'state' key should be for each room as well as which membership block // to put the room into. - deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) + deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) if err != nil { return nil, err } res := types.NewResponse(toPos) for _, delta := range deltas { - err = d.addRoomDeltaToResponse(ctx, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) + err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) if err != nil { return nil, err } } // TODO: This should be done in getStateDeltas - if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil { return nil, err } @@ -292,7 +303,10 @@ func (d *SyncServerDatabase) CompleteSync( if err != nil { return nil, err } - recentEvents := streamEventsToEvents(recentStreamEvents) + + // We don't include a device here as we don't need to send down + // transaction IDs for complete syncs + recentEvents := streamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() @@ -390,7 +404,9 @@ func (d *SyncServerDatabase) addInvitesToResponse( // addRoomDeltaToResponse adds a room state delta to a sync response func (d *SyncServerDatabase) addRoomDeltaToResponse( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + device *authtypes.Device, + txn *sql.Tx, fromPos, toPos types.StreamPosition, delta stateDelta, numRecentEventsPerRoom int, @@ -412,7 +428,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( if err != nil { return err } - recentEvents := streamEventsToEvents(recentStreamEvents) + recentEvents := streamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back // Don't bother appending empty room entries @@ -529,7 +545,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents( } func (d *SyncServerDatabase) getStateDeltas( - ctx context.Context, txn *sql.Tx, + ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, ) ([]stateDelta, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 @@ -578,7 +594,7 @@ func (d *SyncServerDatabase) getStateDeltas( deltas = append(deltas, stateDelta{ membership: membership, membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(stateStreamEvents), + stateEvents: streamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) break @@ -594,7 +610,7 @@ func (d *SyncServerDatabase) getStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ membership: "join", - stateEvents: streamEventsToEvents(state[joinedRoomID]), + stateEvents: streamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } @@ -602,10 +618,25 @@ func (d *SyncServerDatabase) getStateDeltas( return deltas, nil } -func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event { +// streamEventsToEvents converts streamEvent to Event. If device is non-nil and +// matches the streamevent.transactionID device then the transaction ID gets +// added to the unsigned section of the output event. +func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { out := make([]gomatrixserverlib.Event, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].Event + if device != nil && in[i].transactionID != nil { + if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].transactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } } return out } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go index 4712a2c74..5ed701d8e 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go @@ -123,7 +123,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.userID, true).GetListener(req.ctx) + return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go index 79c5a2872..4fa543936 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go @@ -21,6 +21,8 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -262,7 +264,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) { select { case <-time.After(5 * time.Second): return types.StreamPosition(0), fmt.Errorf( - "waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since, + "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): p := listener.GetStreamPosition() @@ -280,7 +282,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) { func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { return syncRequest{ - userID: userID, + device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, since: &since, wantFullState: false, diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go index 7f5259814..3c1befddf 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go @@ -20,6 +20,8 @@ import ( "strconv" "time" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -31,7 +33,7 @@ const defaultTimelineLimit = 20 // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { ctx context.Context - userID string + device authtypes.Device limit int timeout time.Duration since *types.StreamPosition // nil means that no since token was supplied @@ -39,7 +41,7 @@ type syncRequest struct { log *log.Entry } -func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { +func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) { timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" @@ -50,7 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { // TODO: Additional query params: set_presence, filter return &syncRequest{ ctx: req.Context(), - userID: userID, + device: device, timeout: timeout, since: since, wantFullState: wantFullState, diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index 15993b774..703ddd3f1 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -48,7 +48,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // Extract values from request logger := util.GetLogger(req.Context()) userID := device.UserID - syncReq, err := newSyncRequest(req, userID) + syncReq, err := newSyncRequest(req, *device) if err != nil { return util.JSONResponse{ Code: 400, @@ -122,16 +122,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.userID, req.limit) + res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) } else { - res, err = rp.db.IncrementalSync(req.ctx, req.userID, *req.since, currentPos, req.limit) + res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit) } if err != nil { return } - res, err = rp.appendAccountData(res, req.userID, req, currentPos) + res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) return }