Fix NewPaginationTokenFromString, define unit test for it

This commit is contained in:
Neil Alexander 2020-01-23 10:00:18 +00:00
parent fe707a163e
commit fb9fedcc7f
2 changed files with 73 additions and 14 deletions

View file

@ -16,6 +16,7 @@ package types
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
@ -29,6 +30,9 @@ var (
// new instance of PaginationToken with an invalid type (i.e. neither "s"
// nor "t").
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
// ErrInvalidPaginationTokenLen is returned when the pagination token is an
// invalid length
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
)
// StreamPosition represents the offset in the sync stream a client is at.
@ -71,27 +75,28 @@ type PaginationToken struct {
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
// case).
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
token = new(PaginationToken)
// Check if the type is among the known ones.
token.Type = PaginationTokenType(s[:1])
if token.Type != PaginationTokenTypeStream && token.Type != PaginationTokenTypeTopology {
if pduPos, perr := strconv.ParseInt(s, 10, 64); perr != nil {
return nil, ErrInvalidPaginationTokenType
} else {
token.Type = PaginationTokenTypeStream
token.PDUPosition = StreamPosition(pduPos)
return
}
if len(s) == 0 {
return nil, ErrInvalidPaginationTokenLen
}
// Parse the token (aka position).
positions := strings.Split(s[1:], "_")
token = new(PaginationToken)
var positions []string
switch t := PaginationTokenType(s[:1]); t {
case PaginationTokenTypeStream, PaginationTokenTypeTopology:
token.Type = t
positions = strings.Split(s[1:], "_")
default:
token.Type = PaginationTokenTypeStream
positions = strings.Split(s, "_")
}
// Try to get the PDU position.
if len(positions) >= 1 {
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
return nil, err
} else if pduPos < 0 {
return nil, errors.New("negative PDU position not allowed")
} else {
token.PDUPosition = StreamPosition(pduPos)
}
@ -101,6 +106,8 @@ func NewPaginationTokenFromString(s string) (token *PaginationToken, err error)
if len(positions) >= 2 {
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
return nil, err
} else if typPos < 0 {
return nil, errors.New("negative EDU typing position not allowed")
} else {
token.EDUTypingPosition = StreamPosition(typPos)
}

View file

@ -0,0 +1,52 @@
package types
import "testing"
func TestNewPaginationTokenFromString(t *testing.T) {
shouldPass := map[string]PaginationToken{
"2": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 2,
},
"s4": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 4,
},
"s3_1": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 3,
EDUTypingPosition: 1,
},
"t3_1_4": PaginationToken{
Type: PaginationTokenType("t"),
PDUPosition: 3,
EDUTypingPosition: 1,
},
}
shouldFail := []string{
"",
"s_1",
"s_",
"a3_4",
"b",
"b-1",
"-4",
}
for test, expected := range shouldPass {
result, err := NewPaginationTokenFromString(test)
if err != nil {
t.Error(err)
}
if *result != expected {
t.Errorf("expected %v but got %v", expected.String(), result.String())
}
}
for _, test := range shouldFail {
if _, err := NewPaginationTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test)
}
}
}