mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-19 12:53:09 -06:00
Fix NewPaginationTokenFromString, define unit test for it
This commit is contained in:
parent
fe707a163e
commit
fb9fedcc7f
|
|
@ -16,6 +16,7 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -29,6 +30,9 @@ var (
|
||||||
// new instance of PaginationToken with an invalid type (i.e. neither "s"
|
// new instance of PaginationToken with an invalid type (i.e. neither "s"
|
||||||
// nor "t").
|
// nor "t").
|
||||||
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
|
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
|
||||||
|
// ErrInvalidPaginationTokenLen is returned when the pagination token is an
|
||||||
|
// invalid length
|
||||||
|
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamPosition represents the offset in the sync stream a client is at.
|
// StreamPosition represents the offset in the sync stream a client is at.
|
||||||
|
|
@ -71,27 +75,28 @@ type PaginationToken struct {
|
||||||
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
|
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
|
||||||
// case).
|
// case).
|
||||||
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
|
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
|
||||||
token = new(PaginationToken)
|
if len(s) == 0 {
|
||||||
|
return nil, ErrInvalidPaginationTokenLen
|
||||||
// Check if the type is among the known ones.
|
|
||||||
token.Type = PaginationTokenType(s[:1])
|
|
||||||
if token.Type != PaginationTokenTypeStream && token.Type != PaginationTokenTypeTopology {
|
|
||||||
if pduPos, perr := strconv.ParseInt(s, 10, 64); perr != nil {
|
|
||||||
return nil, ErrInvalidPaginationTokenType
|
|
||||||
} else {
|
|
||||||
token.Type = PaginationTokenTypeStream
|
|
||||||
token.PDUPosition = StreamPosition(pduPos)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the token (aka position).
|
token = new(PaginationToken)
|
||||||
positions := strings.Split(s[1:], "_")
|
var positions []string
|
||||||
|
|
||||||
|
switch t := PaginationTokenType(s[:1]); t {
|
||||||
|
case PaginationTokenTypeStream, PaginationTokenTypeTopology:
|
||||||
|
token.Type = t
|
||||||
|
positions = strings.Split(s[1:], "_")
|
||||||
|
default:
|
||||||
|
token.Type = PaginationTokenTypeStream
|
||||||
|
positions = strings.Split(s, "_")
|
||||||
|
}
|
||||||
|
|
||||||
// Try to get the PDU position.
|
// Try to get the PDU position.
|
||||||
if len(positions) >= 1 {
|
if len(positions) >= 1 {
|
||||||
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
|
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
} else if pduPos < 0 {
|
||||||
|
return nil, errors.New("negative PDU position not allowed")
|
||||||
} else {
|
} else {
|
||||||
token.PDUPosition = StreamPosition(pduPos)
|
token.PDUPosition = StreamPosition(pduPos)
|
||||||
}
|
}
|
||||||
|
|
@ -101,6 +106,8 @@ func NewPaginationTokenFromString(s string) (token *PaginationToken, err error)
|
||||||
if len(positions) >= 2 {
|
if len(positions) >= 2 {
|
||||||
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
|
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
} else if typPos < 0 {
|
||||||
|
return nil, errors.New("negative EDU typing position not allowed")
|
||||||
} else {
|
} else {
|
||||||
token.EDUTypingPosition = StreamPosition(typPos)
|
token.EDUTypingPosition = StreamPosition(typPos)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
52
syncapi/types/types_test.go
Normal file
52
syncapi/types/types_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNewPaginationTokenFromString(t *testing.T) {
|
||||||
|
shouldPass := map[string]PaginationToken{
|
||||||
|
"2": PaginationToken{
|
||||||
|
Type: PaginationTokenType("s"),
|
||||||
|
PDUPosition: 2,
|
||||||
|
},
|
||||||
|
"s4": PaginationToken{
|
||||||
|
Type: PaginationTokenType("s"),
|
||||||
|
PDUPosition: 4,
|
||||||
|
},
|
||||||
|
"s3_1": PaginationToken{
|
||||||
|
Type: PaginationTokenType("s"),
|
||||||
|
PDUPosition: 3,
|
||||||
|
EDUTypingPosition: 1,
|
||||||
|
},
|
||||||
|
"t3_1_4": PaginationToken{
|
||||||
|
Type: PaginationTokenType("t"),
|
||||||
|
PDUPosition: 3,
|
||||||
|
EDUTypingPosition: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldFail := []string{
|
||||||
|
"",
|
||||||
|
"s_1",
|
||||||
|
"s_",
|
||||||
|
"a3_4",
|
||||||
|
"b",
|
||||||
|
"b-1",
|
||||||
|
"-4",
|
||||||
|
}
|
||||||
|
|
||||||
|
for test, expected := range shouldPass {
|
||||||
|
result, err := NewPaginationTokenFromString(test)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if *result != expected {
|
||||||
|
t.Errorf("expected %v but got %v", expected.String(), result.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range shouldFail {
|
||||||
|
if _, err := NewPaginationTokenFromString(test); err == nil {
|
||||||
|
t.Errorf("input '%v' should have errored but didn't", test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue