diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 922a7df6e..8916565dc 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -81,7 +81,10 @@ func Setup( })).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { - vars := mux.Vars(req) + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg) })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index d93a55de0..68863d91c 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -566,12 +566,11 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( defer common.EndTransaction(txn, &succeeded) // Get the current sync position which we will base the sync response on. - /* - toPos, err = d.syncPositionTx(ctx, txn) - if err != nil { - return - } - */ + toPos, err = d.syncPositionTx(ctx, txn) + if err != nil { + return + } + // Get the current stream position which we will base the sync response on. pos, err := d.syncStreamPositionTx(ctx, txn) if err != nil { @@ -628,6 +627,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( types.PaginationTokenTypeTopology, backwardTopologyPos, 0, ).String() + // TODO: do we want short-form here? adds complexity elsewhere if prevPDUPos := recentStreamEvents[0].StreamPosition - 1; prevPDUPos > 0 { // Use the short form of batch token for prev_batch jr.Timeline.PrevBatch = strconv.FormatInt(int64(prevPDUPos), 10) diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index c7f9ba0f4..a1bce0a6b 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -52,7 +52,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" - since, err := getSyncStreamPosition(req.URL.Query().Get("since")) + since, err := getPaginationToken(req.URL.Query().Get("since")) if err != nil { return nil, err } @@ -82,11 +82,11 @@ func getTimeout(timeoutMS string) time.Duration { // getPaginationToken tries to parse a 'since' token taken from the API to a // pagination token. If the string is empty then (nil, nil) is returned. // Returns an error if the parsed token's type isn't types.PaginationTokenTypeStream. -func getPaginationToken(since string) (*types.StreamPosition, error) { +func getSyncStreamPosition(since string) (*types.StreamPosition, error) { if since == "" { return nil, nil } - p, err := types.NewPaginationTokenFromString(since) + p, err := getPaginationToken(since) if err != nil { return nil, err } @@ -101,15 +101,10 @@ func getPaginationToken(since string) (*types.StreamPosition, error) { // There are two forms of tokens: The full length form containing all PDU and EDU // positions separated by "_", and the short form containing only the PDU // position. Short form can be used for, e.g., `prev_batch` tokens. -func getSyncStreamPosition(since string) (*types.PaginationToken, error) { +func getPaginationToken(since string) (*types.PaginationToken, error) { if since == "" { return nil, nil } - pos, err := types.NewPaginationTokenFromString(since) - if err != nil { - return nil, err - } - - return pos, nil + return types.NewPaginationTokenFromString(since) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index f27377b88..b596aedfe 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -76,12 +76,18 @@ func NewPaginationTokenFromString(s string) (p *PaginationToken, err error) { // Check if the type is among the known ones. p.Type = PaginationTokenType(s[:1]) if p.Type != PaginationTokenTypeStream && p.Type != PaginationTokenTypeTopology { - err = ErrInvalidPaginationTokenType - return + if pduPos, perr := strconv.ParseInt(s, 10, 64); perr != nil { + return nil, ErrInvalidPaginationTokenType + } else { + // TODO: should this be topograpical? + p.Type = PaginationTokenTypeTopology + p.PDUPosition = StreamPosition(pduPos) + return + } } // Parse the token (aka position). - positions := strings.Split(s[:1], "_") + positions := strings.Split(s[1:], "_") // Try to get the PDU position. if len(positions) >= 1 {