From 31e6a7f1932c11d9b5b682ad06a5b8db9d74a44f Mon Sep 17 00:00:00 2001 From: Sid Karunaratne Date: Wed, 13 May 2020 19:04:54 +0800 Subject: [PATCH 1/2] Enforce `mediaIDRegex` to be only valid `mediaIDCharacters` (#1020) Error messages indicate that: > mediaId must be a non-empty string using only characters in `mediaIDCharacters` However the regex used only required that some characters in the filename match the restriction, not that the entire filename does. This commit ensures that the filename must entirely fullfill the `mediaIDCharacters` restriction Signed-off-by: Sid Karunaratne Co-authored-by: Kegsay --- mediaapi/routing/download.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 9feca90e9..75df313f6 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -43,7 +43,7 @@ import ( const mediaIDCharacters = "A-Za-z0-9_=-" // Note: unfortunately regex.MustCompile() cannot be assigned to a const -var mediaIDRegex = regexp.MustCompile("[" + mediaIDCharacters + "]+") +var mediaIDRegex = regexp.MustCompile("^[" + mediaIDCharacters + "]+$") // downloadRequest metadata included in or derivable from a download or thumbnail request // https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid From 5e9dce1c0c66736937eeddd5c33c92700d9a65a7 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 13 May 2020 12:14:50 +0100 Subject: [PATCH 2/2] syncapi: Rename and split out tokens (#1025) * syncapi: Rename and split out tokens Previously we used the badly named `PaginationToken` which was used for both `/sync` and `/messages` requests. This quickly became confusing because named fields like `PDUPosition` meant different things depending on the token type. Instead, we now have two token types: `TopologyToken` and `StreamingToken`, both of which have fields which make more sense for their specific situations. Updated the codebase to use one or the other. `PaginationToken` still lives on as `syncToken`, an unexported type which both tokens rely on. This allows us to guarantee that the specific mappings of positions to a string remain solely under the control of the `types` package. This enables us to move high-level conceptual things like "decrement this topological token" to function calls e.g `TopologicalToken.Decrement()`. Currently broken because `/messages` seemingly used both stream and topological tokens, though I need to confirm this. * final tweaks/hacks * spurious logging * Review comments and linting --- syncapi/consumers/clientapi.go | 2 +- syncapi/consumers/eduserver.go | 6 +- syncapi/consumers/roomserver.go | 4 +- syncapi/routing/messages.go | 77 +++---- syncapi/storage/interface.go | 11 +- syncapi/storage/postgres/syncserver.go | 127 +++++------ syncapi/storage/sqlite3/syncserver.go | 154 ++++++------- syncapi/storage/storage_test.go | 64 +++--- syncapi/sync/notifier.go | 10 +- syncapi/sync/notifier_test.go | 70 ++---- syncapi/sync/request.go | 26 +-- syncapi/sync/requestpool.go | 6 +- syncapi/sync/userstream.go | 12 +- syncapi/types/types.go | 293 ++++++++++++++++--------- syncapi/types/types_test.go | 34 +-- 15 files changed, 457 insertions(+), 439 deletions(-) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f5b8c43ec..b65d01a04 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0)) return nil } diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver.go index 249452af5..ece999d59 100644 --- a/syncapi/consumers/eduserver.go +++ b/syncapi/consumers/eduserver.go @@ -65,9 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.notifier.OnNewEvent( nil, roomID, nil, - types.PaginationToken{ - EDUTypingPosition: types.StreamPosition(latestSyncPosition), - }, + types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)), ) }) @@ -96,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) } - s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos}) + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos)) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 987cc5df6..368420a6c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -146,7 +146,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0)) return nil } @@ -164,7 +164,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0)) return nil } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 270b0ee95..72c306d4f 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -38,8 +38,9 @@ type messagesReq struct { federation *gomatrixserverlib.FederationClient cfg *config.Dendrite roomID string - from *types.PaginationToken - to *types.PaginationToken + from *types.TopologyToken + to *types.TopologyToken + fromStream *types.StreamingToken wasToProvided bool limit int backwardOrdering bool @@ -66,11 +67,16 @@ func OnIncomingMessagesRequest( // Extract parameters from the request's URL. // Pagination tokens. - from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) + var fromStream *types.StreamingToken + from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from")) if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from")) + fromStream = &fs + if err2 != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()), + } } } @@ -88,10 +94,10 @@ func OnIncomingMessagesRequest( // Pagination tokens. To is optional, and its default value depends on the // direction ("b" or "f"). - var to *types.PaginationToken + var to types.TopologyToken wasToProvided := true if s := req.URL.Query().Get("to"); len(s) > 0 { - to, err = types.NewPaginationTokenFromString(s) + to, err = types.NewTopologyTokenFromString(s) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -139,8 +145,9 @@ func OnIncomingMessagesRequest( federation: federation, cfg: cfg, roomID: roomID, - from: from, - to: to, + from: &from, + to: &to, + fromStream: fromStream, wasToProvided: wasToProvided, limit: limit, backwardOrdering: backwardOrdering, @@ -178,12 +185,20 @@ func OnIncomingMessagesRequest( // remote homeserver. func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, - end *types.PaginationToken, err error, + end types.TopologyToken, err error, ) { // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, - ) + var streamEvents []types.StreamEvent + if r.fromStream != nil { + toStream := r.to.StreamToken() + streamEvents, err = r.db.GetEventsInStreamingRange( + r.ctx, r.fromStream, &toStream, r.roomID, r.limit, r.backwardOrdering, + ) + } else { + streamEvents, err = r.db.GetEventsInTopologicalRange( + r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + ) + } if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return @@ -206,7 +221,7 @@ func (r *messagesReq) retrieveEvents() ( // If we didn't get any event, we don't need to proceed any further. if len(events) == 0 { - return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } // Sort the events to ensure we send them in the right order. @@ -246,12 +261,8 @@ func (r *messagesReq) retrieveEvents() ( } // Generate pagination tokens to send to the client using the positions // retrieved previously. - start = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, startPos, startStreamPos, - ) - end = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, endPos, endStreamPos, - ) + start = types.NewTopologyToken(startPos, startStreamPos) + end = types.NewTopologyToken(endPos, endStreamPos) if r.backwardOrdering { // A stream/topological position is a cursor located between two events. @@ -259,14 +270,7 @@ func (r *messagesReq) retrieveEvents() ( // we consider a left to right chronological order), tokens need to refer // to them by the event on their left, therefore we need to decrement the // end position we send in the response if we're going backward. - end.PDUPosition-- - end.EDUTypingPosition += 1000 - } - - // The lowest token value is 1, therefore we need to manually set it to that - // value if we're below it. - if end.PDUPosition < types.StreamPosition(1) { - end.PDUPosition = types.StreamPosition(1) + end.Decrement() } return clientEvents, start, end, err @@ -317,11 +321,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // The condition in the SQL query is a strict "greater than" so // we need to check against to-1. streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) - isSetLargeEnough = (r.to.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos) } } else { streamPos := types.StreamPosition(streamEvents[0].StreamPosition) - isSetLargeEnough = (r.from.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos) } } @@ -424,18 +428,17 @@ func (r *messagesReq) backfill(roomID string, fromEventIDs []string, limit int) func setToDefault( ctx context.Context, db storage.Database, backwardOrdering bool, roomID string, -) (to *types.PaginationToken, err error) { +) (to types.TopologyToken, err error) { if backwardOrdering { // go 1 earlier than the first event so we correctly fetch the earliest event - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to = types.NewTopologyToken(0, 0) } else { - var pos, stream types.StreamPosition - pos, stream, err = db.MaxTopologicalPosition(ctx, roomID) + var depth, stream types.StreamPosition + depth, stream, err = db.MaxTopologicalPosition(ctx, roomID) if err != nil { return } - - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, stream) + to = types.NewTopologyToken(depth, stream) } return diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 7d6376438..63af11365 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -50,13 +50,13 @@ type Database interface { // Returns an error if there was an issue with the retrieval. GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) // SyncPosition returns the latest positions for syncing. - SyncPosition(ctx context.Context) (types.PaginationToken, error) + SyncPosition(ctx context.Context) (types.StreamingToken, error) // IncrementalSync returns all the data needed in order to create an incremental // sync response for the given user. Events returned will include any client // transaction IDs associated with the given device. These transaction IDs come // from when the device sent the event via an API that included a transaction // ID. - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) // CompleteSync returns a complete /sync API response for the given user. CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or @@ -88,9 +88,10 @@ type Database interface { // RemoveTypingUser removes a typing user from the typing cache. // Returns the newly calculated sync position for typing notifications. RemoveTypingUser(userID, roomID string) types.StreamPosition - // GetEventsInRange retrieves all of the events on a given ordering using the - // given extremities and limit. - GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (depth types.StreamPosition, stream types.StreamPosition, err error) // EventsAtTopologicalPosition returns all of the events matching a given diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 1845ac386..d45bc09e5 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -228,69 +228,68 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } -func (d *SyncServerDatasource) GetEventsInRange( +func (d *SyncServerDatasource) GetEventsInTopologicalRange( ctx context.Context, - from, to *types.PaginationToken, + from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool, ) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - forwardMicroLimit = from.EDUTypingPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.Depth() + forwardLimit = from.Depth() + forwardMicroLimit = from.PDUPosition() + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.Depth() + forwardLimit = to.Depth() + } - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, + ) + if err != nil { return } - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return +} +// GetEventsInStreamingRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInStreamingRange( + ctx context.Context, + from, to *types.StreamingToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { if backwardOrdering { // When using backward ordering, we want the most recent events first. if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, ); err != nil { return } } else { // When using forward ordering, we want the least recent events first. if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, ); err != nil { return } } - - return + return events, err } -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.StreamingToken, error) { return d.syncPositionTx(ctx, nil) } @@ -353,7 +352,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { +) (sp types.StreamingToken, err error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { @@ -373,8 +372,7 @@ func (d *SyncServerDatasource) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) + sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.eduCache.GetLatestSyncPosition())) return } @@ -439,7 +437,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, + since types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -448,7 +446,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), + roomID, int64(since.EDUPosition()), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -473,12 +471,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { + if fromPos.EDUPosition() != toPos.EDUPosition() { err = d.addTypingDeltaToResponse( fromPos, joinedRoomIDs, res, ) @@ -490,7 +488,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { @@ -499,9 +497,9 @@ func (d *SyncServerDatasource) IncrementalSync( var joinedRoomIDs []string var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { + if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, + ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ) } else { joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( @@ -530,7 +528,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.PaginationToken, + toPos types.StreamingToken, joinedRoomIDs []string, err error, ) { @@ -577,7 +575,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), numRecentEventsPerRoom, true, true, ) if err != nil { @@ -588,27 +586,25 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // oldest event in the room's topology. var backwardTopologyPos, backwardStreamPos types.StreamPosition backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- + if err != nil { + return } + prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos) + prevBatch.Decrement() // We don't include a device here as we don't need to send down // transaction IDs for complete syncs recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil { return } @@ -628,7 +624,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, + types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -757,14 +753,15 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) + prevBatch := types.NewTopologyToken( + backwardTopologyPos, backwardStreamPos, + ) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -775,9 +772,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 314ea2aa3..212f882b1 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -269,63 +269,63 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } -// GetEventsInRange retrieves all of the events on a given ordering using the +// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the // given extremities and limit. -func (d *SyncServerDatasource) GetEventsInRange( +func (d *SyncServerDatasource) GetEventsInTopologicalRange( ctx context.Context, - from, to *types.PaginationToken, + from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool, ) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // TODO: ARGH CONFUSING - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - forwardMicroLimit = from.EDUTypingPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } + // TODO: ARGH CONFUSING + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.Depth() + forwardLimit = from.Depth() + forwardMicroLimit = from.PDUPosition() + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.Depth() + forwardLimit = to.Depth() + } - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, + ) + if err != nil { return } - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return +} +// GetEventsInStreamingRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInStreamingRange( + ctx context.Context, + from, to *types.StreamingToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { if backwardOrdering { // When using backward ordering, we want the most recent events first. if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, ); err != nil { return } } else { // When using forward ordering, we want the least recent events first. if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, ); err != nil { return } @@ -334,10 +334,14 @@ func (d *SyncServerDatasource) GetEventsInRange( } // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - tok, err = d.syncPositionTx(ctx, txn) - return err + pos, err := d.syncPositionTx(ctx, txn) + if err != nil { + return err + } + tok = *pos + return nil }) return } @@ -412,30 +416,31 @@ func (d *SyncServerDatasource) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { +) (*types.StreamingToken, error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { - return sp, err + return nil, err } maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) if err != nil { - return sp, err + return nil, err } if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) if err != nil { - return sp, err + return nil, err } if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) - sp.Type = types.PaginationTokenTypeStream - return + sp := types.NewStreamToken( + types.StreamPosition(maxEventID), + types.StreamPosition(d.eduCache.GetLatestSyncPosition()), + ) + return &sp, nil } // addPDUDeltaToResponse adds all PDU deltas to a sync response. @@ -499,7 +504,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, + since types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -508,7 +513,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), + roomID, int64(since.EDUPosition()), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -533,12 +538,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { + if fromPos.EDUPosition() != toPos.EDUPosition() { err = d.addTypingDeltaToResponse( fromPos, joinedRoomIDs, res, ) @@ -555,18 +560,21 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { + fmt.Println("from ", fromPos, "to", toPos) nextBatchPos := fromPos.WithUpdates(toPos) res := types.NewResponse(nextBatchPos) + fmt.Println("from ", fromPos, "to", toPos, "next", nextBatchPos) var joinedRoomIDs []string var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { + fmt.Println("from", fromPos.PDUPosition(), "to", toPos.PDUPosition()) + if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, + ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ) } else { joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( @@ -595,7 +603,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.PaginationToken, + toPos *types.StreamingToken, joinedRoomIDs []string, err error, ) { @@ -621,7 +629,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( return } - res = types.NewResponse(toPos) + res = types.NewResponse(*toPos) // Extract room state and recent events for all rooms the user is joined to. joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -643,7 +651,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), numRecentEventsPerRoom, true, true, ) if err != nil { @@ -655,28 +663,22 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // oldest event in the room's topology. var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- - backwardTopologyStreamPos += 1000 // this has to be bigger than the number of events we backfill per request - } + prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardTopologyStreamPos) + prevBatch.Decrement() // We don't include a device here as we don't need to send down // transaction IDs for complete syncs recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardTopologyStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil { return } @@ -697,7 +699,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, + types.NewStreamToken(0, 0), *toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -860,14 +862,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) + prevBatch := types.NewTopologyToken( + backwardTopologyPos, backwardStreamPos, + ) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -878,9 +880,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index b951efa45..f7fa1a870 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -154,10 +154,10 @@ func TestSyncResponse(t *testing.T) { { Name: "IncrementalSync penultimate", DoSync: func() (*types.Response, error) { - from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are at the penultimate event - types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + from := types.NewStreamToken( // pretend we are at the penultimate event + positions[len(positions)-2], types.StreamPosition(0), ) - return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) }, WantTimeline: events[len(events)-1:], }, @@ -166,11 +166,11 @@ func TestSyncResponse(t *testing.T) { { Name: "IncrementalSync limited", DoSync: func() (*types.Response, error) { - from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are 10 events behind - types.PaginationTokenTypeStream, positions[len(positions)-11], types.StreamPosition(0), + from := types.NewStreamToken( // pretend we are 10 events behind + positions[len(positions)-11], types.StreamPosition(0), ) // limit is set to 5 - return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) }, // want the last 5 events, NOT the last 10. WantTimeline: events[len(events)-5:], @@ -207,7 +207,7 @@ func TestSyncResponse(t *testing.T) { if err != nil { st.Fatalf("failed to do sync: %s", err) } - next := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeStream, latest.PDUPosition, latest.EDUTypingPosition) + next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition()) if res.NextBatch != next.String() { st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) } @@ -230,11 +230,11 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { if err != nil { t.Fatalf("failed to get SyncPosition: %s", err) } - from := types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + from := types.NewStreamToken( + positions[len(positions)-2], types.StreamPosition(0), ) - res, err := db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) if err != nil { t.Fatalf("failed to IncrementalSync with latest token") } @@ -249,14 +249,14 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { if prev == "" { t.Fatalf("IncrementalSync expected prev_batch token") } - prevBatchToken, err := types.NewPaginationTokenFromString(prev) + prevBatchToken, err := types.NewTopologyTokenFromString(prev) if err != nil { - t.Fatalf("failed to NewPaginationTokenFromString : %s", err) + t.Fatalf("failed to NewTopologyTokenFromString : %s", err) } // backpaginate 5 messages starting at the latest position. // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) - paginatedEvents, err := db.GetEventsInRange(ctx, prevBatchToken, to, testRoomID, 5, true) + to := types.NewTopologyToken(0, 0) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -275,10 +275,10 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) { t.Fatalf("failed to get SyncPosition: %s", err) } // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewStreamToken(0, 0) // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, &latest, to, testRoomID, 5, true) + paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -296,12 +296,12 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { if err != nil { t.Fatalf("failed to get MaxTopologicalPosition: %s", err) } - from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream) + from := types.NewTopologyToken(latest, latestStream) // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, 5, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -366,14 +366,14 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { if err != nil { t.Fatalf("failed to get EventPositionInTopology for event: %s", err) } - fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos) - fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos) + fromLatest := types.NewTopologyToken(latestPos, latestStreamPos) + fromFork := types.NewTopologyToken(topoPos, streamPos) // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) testCases := []struct { Name string - From *types.PaginationToken + From types.TopologyToken Limit int Wants []gomatrixserverlib.HeaderedEvent }{ @@ -399,7 +399,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { for _, tc := range testCases { // backpaginate messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &tc.From, &to, testRoomID, tc.Limit, true) if err != nil { t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err) } @@ -446,13 +446,13 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { } // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) // starting at `from`, backpaginate to the beginning of time, asserting as we go. chunkSize = 3 events = reversed(events) for i := 0; i < len(events); i += chunkSize { - paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, chunkSize, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, from, &to, testRoomID, chunkSize, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -506,19 +506,15 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr } } -func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.PaginationToken { +func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.TopologyToken { pos, spos, err := db.EventPositionInTopology(ctx, eventID) if err != nil { t.Fatalf("failed to get EventPositionInTopology: %s", err) } - if pos-1 <= 0 { - pos = types.StreamPosition(1) - } else { - pos = pos - 1 - spos += 1000 // this has to be bigger than the chunk limit - } - return types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, spos) + tok := types.NewTopologyToken(pos, spos) + tok.Decrement() + return &tok } func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 0d8050112..b3ed5cd03 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -36,7 +36,7 @@ type Notifier struct { // Protects currPos and userStreams. streamLock *sync.Mutex // The latest sync position - currPos types.PaginationToken + currPos types.StreamingToken // A map of user_id => UserStream which can be used to wake a given user's /sync request. userStreams map[string]*UserStream // The last time we cleaned out stale entries from the userStreams map @@ -46,7 +46,7 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.PaginationToken) *Notifier { +func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), @@ -68,7 +68,7 @@ func NewNotifier(pos types.PaginationToken) *Notifier { // event type it handles, leaving other fields as 0. func (n *Notifier) OnNewEvent( ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, - posUpdate types.PaginationToken, + posUpdate types.StreamingToken, ) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. @@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { } // CurrentPosition returns the current sync position -func (n *Notifier) CurrentPosition() types.PaginationToken { +func (n *Notifier) CurrentPosition() types.StreamingToken { n.streamLock.Lock() defer n.streamLock.Unlock() @@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) { +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { stream := n.fetchUserStream(userID, false) if stream != nil { diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 350d757c6..7d979fcc9 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -33,11 +33,11 @@ var ( randomMessageEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent - syncPositionVeryOld types.PaginationToken - syncPositionBefore types.PaginationToken - syncPositionAfter types.PaginationToken - syncPositionNewEDU types.PaginationToken - syncPositionAfter2 types.PaginationToken + syncPositionVeryOld = types.NewStreamToken(5, 0) + syncPositionBefore = types.NewStreamToken(11, 0) + syncPositionAfter = types.NewStreamToken(12, 0) + syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1) + syncPositionAfter2 = types.NewStreamToken(13, 0) ) var ( @@ -47,26 +47,6 @@ var ( ) func init() { - baseSyncPos := types.PaginationToken{ - PDUPosition: 0, - EDUTypingPosition: 0, - } - - syncPositionVeryOld = baseSyncPos - syncPositionVeryOld.PDUPosition = 5 - - syncPositionBefore = baseSyncPos - syncPositionBefore.PDUPosition = 11 - - syncPositionAfter = baseSyncPos - syncPositionAfter.PDUPosition = 12 - - syncPositionNewEDU = syncPositionAfter - syncPositionNewEDU.EDUTypingPosition = 1 - - syncPositionAfter2 = baseSyncPos - syncPositionAfter2.PDUPosition = 13 - var err error err = json.Unmarshal([]byte(`{ "_room_version": "1", @@ -118,6 +98,12 @@ func init() { } } +func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { + if got.String() != want.String() { + t.Fatalf("mustEqualPositions got %s want %s", got.String(), want.String()) + } +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { n := NewNotifier(syncPositionBefore) @@ -125,9 +111,7 @@ func TestImmediateNotification(t *testing.T) { if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } - if pos != syncPositionBefore { - t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos) - } + mustEqualPositions(t, pos, syncPositionBefore) } // Test that new events to a joined room unblocks the request. @@ -144,9 +128,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() @@ -172,9 +154,7 @@ func TestNewInviteEventForUser(t *testing.T) { if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() @@ -200,9 +180,7 @@ func TestEDUWakeup(t *testing.T) { if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionNewEDU { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos) - } + mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() }() @@ -228,9 +206,7 @@ func TestMultipleRequestWakeup(t *testing.T) { if err != nil { t.Errorf("TestMultipleRequestWakeup error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() } go poll() @@ -268,9 +244,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() bobStream := lockedFetchUserStream(n, bob) @@ -287,9 +261,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos) - } + mustEqualPositions(t, pos, syncPositionAfter2) aliceWG.Done() }() @@ -312,13 +284,13 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.PaginationToken{}, fmt.Errorf( + return types.StreamingToken{}, fmt.Errorf( "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): @@ -344,7 +316,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream { return n.fetchUserStream(userID, true) } -func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest { +func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { return syncRequest{ device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index f2e199d23..66663cf0a 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -36,7 +36,7 @@ type syncRequest struct { device authtypes.Device limit int timeout time.Duration - since *types.PaginationToken // nil means that no since token was supplied + since *types.StreamingToken // nil means that no since token was supplied wantFullState bool log *log.Entry } @@ -45,9 +45,14 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" - since, err := getPaginationToken(req.URL.Query().Get("since")) - if err != nil { - return nil, err + var since *types.StreamingToken + sinceStr := req.URL.Query().Get("since") + if sinceStr != "" { + tok, err := types.NewStreamTokenFromString(sinceStr) + if err != nil { + return nil, err + } + since = &tok } // TODO: Additional query params: set_presence, filter return &syncRequest{ @@ -71,16 +76,3 @@ func getTimeout(timeoutMS string) time.Duration { } return time.Duration(i) * time.Millisecond } - -// getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// types.PaginationToken. If the string is empty then (nil, nil) is returned. -// There are two forms of tokens: The full length form containing all PDU and EDU -// positions separated by "_", and the short form containing only the PDU -// position. Short form can be used for, e.g., `prev_batch` tokens. -func getPaginationToken(since string) (*types.PaginationToken, error) { - if since == "" { - return nil, nil - } - - return types.NewPaginationTokenFromString(since) -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 69efd8aa8..126e76f5b 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -132,7 +132,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) { +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) @@ -145,7 +145,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Pagin } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead - res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) return } @@ -187,7 +187,7 @@ func (rp *RequestPool) appendAccountData( // Sync is not initial, get all account data since the latest sync dataTypes, err := rp.db.GetAccountDataInRange( req.ctx, userID, - types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos), + types.StreamPosition(req.since.PDUPosition()), types.StreamPosition(currentPos), accountDataFilter, ) if err != nil { diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index 88867005e..b2eafa3dc 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -34,7 +34,7 @@ type UserStream struct { // Closed when there is an update. signalChannel chan struct{} // The last sync position that there may have been an update for the user - pos types.PaginationToken + pos types.StreamingToken // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting @@ -50,7 +50,7 @@ type UserStreamListener struct { } // NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { +func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { return &UserStream{ UserID: userID, timeOfLastChannel: time.Now(), @@ -83,7 +83,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.PaginationToken) { +func (s *UserStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -116,9 +116,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { return s.timeOfLastChannel } -// GetStreamPosition returns last sync position which the UserStream was +// GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { +func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +130,7 @@ func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} { +func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c04fe5219..c1b6d7dd5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -27,19 +27,19 @@ import ( ) var ( - // ErrInvalidPaginationTokenType is returned when an attempt at creating a - // new instance of PaginationToken with an invalid type (i.e. neither "s" + // ErrInvalidSyncTokenType is returned when an attempt at creating a + // new instance of SyncToken with an invalid type (i.e. neither "s" // nor "t"). - ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") - // ErrInvalidPaginationTokenLen is returned when the pagination token is an + ErrInvalidSyncTokenType = fmt.Errorf("Sync token has an unknown prefix (should be either s or t)") + // ErrInvalidSyncTokenLen is returned when the pagination token is an // invalid length - ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length") + ErrInvalidSyncTokenLen = fmt.Errorf("Sync token has an invalid length") ) // StreamPosition represents the offset in the sync stream a client is at. type StreamPosition int64 -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { gomatrixserverlib.HeaderedEvent StreamPosition StreamPosition @@ -47,118 +47,201 @@ type StreamEvent struct { ExcludeFromSync bool } -// PaginationTokenType represents the type of a pagination token. +// SyncTokenType represents the type of a sync token. // It can be either "s" (representing a position in the whole stream of events) // or "t" (representing a position in a room's topology/depth). -type PaginationTokenType string +type SyncTokenType string const ( - // PaginationTokenTypeStream represents a position in the server's whole + // SyncTokenTypeStream represents a position in the server's whole // stream of events - PaginationTokenTypeStream PaginationTokenType = "s" - // PaginationTokenTypeTopology represents a position in a room's topology. - PaginationTokenTypeTopology PaginationTokenType = "t" + SyncTokenTypeStream SyncTokenType = "s" + // SyncTokenTypeTopology represents a position in a room's topology. + SyncTokenTypeTopology SyncTokenType = "t" ) -// PaginationToken represents a pagination token, used for interactions with -// /sync or /messages, for example. -type PaginationToken struct { - //Position StreamPosition - Type PaginationTokenType - // For /sync, this is the PDU position. For /messages, this is the topological position (depth). - // TODO: Given how different the positions are depending on the token type, they should probably be renamed - // or use different structs altogether. - PDUPosition StreamPosition - // For /sync, this is the EDU position. For /messages, this is the stream (PDU) position. - // TODO: Given how different the positions are depending on the token type, they should probably be renamed - // or use different structs altogether. - EDUTypingPosition StreamPosition +type StreamingToken struct { + syncToken } -// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" -// represents the type of a pagination token and "yyyy..." the token itself, and -// parses it in order to create a new instance of PaginationToken. Returns an -// error if the token couldn't be parsed into an int64, or if the token type -// isn't a known type (returns ErrInvalidPaginationTokenType in the latter -// case). -func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) { - if len(s) == 0 { - return nil, ErrInvalidPaginationTokenLen - } +func (t *StreamingToken) PDUPosition() StreamPosition { + return t.Positions[0] +} +func (t *StreamingToken) EDUPosition() StreamPosition { + return t.Positions[1] +} - token = new(PaginationToken) - var positions []string - - switch t := PaginationTokenType(s[:1]); t { - case PaginationTokenTypeStream, PaginationTokenTypeTopology: - token.Type = t - positions = strings.Split(s[1:], "_") - default: - token.Type = PaginationTokenTypeStream - positions = strings.Split(s, "_") - } - - // Try to get the PDU position. - if len(positions) >= 1 { - if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil { - return nil, err - } else if pduPos < 0 { - return nil, errors.New("negative PDU position not allowed") - } else { - token.PDUPosition = StreamPosition(pduPos) +// IsAfter returns true if ANY position in this token is greater than `other`. +func (t *StreamingToken) IsAfter(other StreamingToken) bool { + for i := range other.Positions { + if t.Positions[i] > other.Positions[i] { + return true } } + return false +} - // Try to get the typing position. - if len(positions) >= 2 { - if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil { - return nil, err - } else if typPos < 0 { - return nil, errors.New("negative EDU typing position not allowed") - } else { - token.EDUTypingPosition = StreamPosition(typPos) +// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. +// If the latter StreamingToken contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. +func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { + ret.Type = t.Type + ret.Positions = make([]StreamPosition, len(t.Positions)) + for i := range t.Positions { + ret.Positions[i] = t.Positions[i] + if other.Positions[i] == 0 { + continue } - } - - return -} - -// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a -// StreamPosition and returns an instance of PaginationToken. -func NewPaginationTokenFromTypeAndPosition( - t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition, -) (p *PaginationToken) { - return &PaginationToken{ - Type: t, - PDUPosition: pdupos, - EDUTypingPosition: typpos, - } -} - -// String translates a PaginationToken to a string of the "xyyyy..." (see -// NewPaginationToken to know what it represents). -func (p *PaginationToken) String() string { - return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition) -} - -// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken. -// If the latter PaginationToken contains a field that is not 0, it is considered an update, -// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called. -func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken { - ret := *pt - if other.PDUPosition != 0 { - ret.PDUPosition = other.PDUPosition - } - if other.EDUTypingPosition != 0 { - ret.EDUTypingPosition = other.EDUTypingPosition + ret.Positions[i] = other.Positions[i] } return ret } -// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken. -func (sp *PaginationToken) IsAfter(other PaginationToken) bool { - return sp.PDUPosition > other.PDUPosition || - sp.EDUTypingPosition > other.EDUTypingPosition +type TopologyToken struct { + syncToken +} + +func (t *TopologyToken) Depth() StreamPosition { + return t.Positions[0] +} +func (t *TopologyToken) PDUPosition() StreamPosition { + return t.Positions[1] +} +func (t *TopologyToken) StreamToken() StreamingToken { + return NewStreamToken(t.PDUPosition(), 0) +} +func (t *TopologyToken) String() string { + return t.syncToken.String() +} + +// Decrement the topology token to one event earlier. +func (t *TopologyToken) Decrement() { + depth := t.Positions[0] + pduPos := t.Positions[1] + if depth-1 <= 0 { + depth = 1 + } else { + depth-- + pduPos += 1000 + } + // The lowest token value is 1, therefore we need to manually set it to that + // value if we're below it. + if depth < 1 { + depth = 1 + } + t.Positions = []StreamPosition{ + depth, pduPos, + } +} + +// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x" +// represents the type of a pagination token and "yyyy..." the token itself, and +// parses it in order to create a new instance of SyncToken. Returns an +// error if the token couldn't be parsed into an int64, or if the token type +// isn't a known type (returns ErrInvalidSyncTokenType in the latter +// case). +func newSyncTokenFromString(s string) (token *syncToken, err error) { + if len(s) == 0 { + return nil, ErrInvalidSyncTokenLen + } + + token = new(syncToken) + var positions []string + + switch t := SyncTokenType(s[:1]); t { + case SyncTokenTypeStream, SyncTokenTypeTopology: + token.Type = t + positions = strings.Split(s[1:], "_") + default: + return nil, ErrInvalidSyncTokenType + } + + for _, pos := range positions { + if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { + return nil, err + } else if posInt < 0 { + return nil, errors.New("negative position not allowed") + } else { + token.Positions = append(token.Positions, StreamPosition(posInt)) + } + } + return +} + +// NewTopologyToken creates a new sync token for /messages +func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken { + if depth < 0 { + depth = 1 + } + return TopologyToken{ + syncToken: syncToken{ + Type: SyncTokenTypeTopology, + Positions: []StreamPosition{depth, streamPos}, + }, + } +} +func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeTopology { + err = fmt.Errorf("token %s is not a topology token", tok) + return + } + if len(t.Positions) != 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions)) + return + } + return TopologyToken{ + syncToken: *t, + }, nil +} + +// NewStreamToken creates a new sync token for /sync +func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken { + return StreamingToken{ + syncToken: syncToken{ + Type: SyncTokenTypeStream, + Positions: []StreamPosition{pduPos, eduPos}, + }, + } +} +func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeStream { + err = fmt.Errorf("token %s is not a streaming token", tok) + return + } + if len(t.Positions) != 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions)) + return + } + return StreamingToken{ + syncToken: *t, + }, nil +} + +// syncToken represents a syncapi token, used for interactions with +// /sync or /messages, for example. +type syncToken struct { + Type SyncTokenType + // A list of stream positions, their meanings vary depending on the token type. + Positions []StreamPosition +} + +// String translates a SyncToken to a string of the "xyyyy..." (see +// NewSyncToken to know what it represents). +func (p *syncToken) String() string { + posStr := make([]string, len(p.Positions)) + for i := range p.Positions { + posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10) + } + + return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_")) } // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -185,7 +268,7 @@ type Response struct { } // NewResponse creates an empty response with initialised maps. -func NewResponse(token PaginationToken) *Response { +func NewResponse(token StreamingToken) *Response { res := Response{ NextBatch: token.String(), } @@ -202,14 +285,6 @@ func NewResponse(token PaginationToken) *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) - // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume - // we'll always return a stream token. - res.NextBatch = NewPaginationTokenFromTypeAndPosition( - PaginationTokenTypeStream, - StreamPosition(token.PDUPosition), - StreamPosition(token.EDUTypingPosition), - ).String() - return &res } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index f4c84e0d1..1e27a8e32 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,26 +2,11 @@ package types import "testing" -func TestNewPaginationTokenFromString(t *testing.T) { - shouldPass := map[string]PaginationToken{ - "2": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 2, - }, - "s4": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 4, - }, - "s3_1": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 3, - EDUTypingPosition: 1, - }, - "t3_1_4": PaginationToken{ - Type: PaginationTokenTypeTopology, - PDUPosition: 3, - EDUTypingPosition: 1, - }, +func TestNewSyncTokenFromString(t *testing.T) { + shouldPass := map[string]syncToken{ + "s4_0": NewStreamToken(4, 0).syncToken, + "s3_1": NewStreamToken(3, 1).syncToken, + "t3_1": NewTopologyToken(3, 1).syncToken, } shouldFail := []string{ @@ -32,20 +17,21 @@ func TestNewPaginationTokenFromString(t *testing.T) { "b", "b-1", "-4", + "2", } for test, expected := range shouldPass { - result, err := NewPaginationTokenFromString(test) + result, err := newSyncTokenFromString(test) if err != nil { t.Error(err) } - if *result != expected { - t.Errorf("expected %v but got %v", expected.String(), result.String()) + if result.String() != expected.String() { + t.Errorf("%s expected %v but got %v", test, expected.String(), result.String()) } } for _, test := range shouldFail { - if _, err := NewPaginationTokenFromString(test); err == nil { + if _, err := newSyncTokenFromString(test); err == nil { t.Errorf("input '%v' should have errored but didn't", test) } }