Fix NewPaginationTokenFromString, define unit test for it

This commit is contained in:
Neil Alexander 2020-01-23 10:00:18 +00:00
parent fe707a163e
commit fb9fedcc7f
2 changed files with 73 additions and 14 deletions

View file

@ -16,6 +16,7 @@ package types
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -29,6 +30,9 @@ var (
// new instance of PaginationToken with an invalid type (i.e. neither "s" // new instance of PaginationToken with an invalid type (i.e. neither "s"
// nor "t"). // nor "t").
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
// ErrInvalidPaginationTokenLen is returned when the pagination token is an
// invalid length
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
) )
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
@ -71,27 +75,28 @@ type PaginationToken struct {
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter // isn't a known type (returns ErrInvalidPaginationTokenType in the latter
// case). // case).
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) { func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
if len(s) == 0 {
return nil, ErrInvalidPaginationTokenLen
}
token = new(PaginationToken) token = new(PaginationToken)
var positions []string
// Check if the type is among the known ones. switch t := PaginationTokenType(s[:1]); t {
token.Type = PaginationTokenType(s[:1]) case PaginationTokenTypeStream, PaginationTokenTypeTopology:
if token.Type != PaginationTokenTypeStream && token.Type != PaginationTokenTypeTopology { token.Type = t
if pduPos, perr := strconv.ParseInt(s, 10, 64); perr != nil { positions = strings.Split(s[1:], "_")
return nil, ErrInvalidPaginationTokenType default:
} else {
token.Type = PaginationTokenTypeStream token.Type = PaginationTokenTypeStream
token.PDUPosition = StreamPosition(pduPos) positions = strings.Split(s, "_")
return
} }
}
// Parse the token (aka position).
positions := strings.Split(s[1:], "_")
// Try to get the PDU position. // Try to get the PDU position.
if len(positions) >= 1 { if len(positions) >= 1 {
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil { if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
return nil, err return nil, err
} else if pduPos < 0 {
return nil, errors.New("negative PDU position not allowed")
} else { } else {
token.PDUPosition = StreamPosition(pduPos) token.PDUPosition = StreamPosition(pduPos)
} }
@ -101,6 +106,8 @@ func NewPaginationTokenFromString(s string) (token *PaginationToken, err error)
if len(positions) >= 2 { if len(positions) >= 2 {
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil { if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
return nil, err return nil, err
} else if typPos < 0 {
return nil, errors.New("negative EDU typing position not allowed")
} else { } else {
token.EDUTypingPosition = StreamPosition(typPos) token.EDUTypingPosition = StreamPosition(typPos)
} }

View file

@ -0,0 +1,52 @@
package types
import "testing"
func TestNewPaginationTokenFromString(t *testing.T) {
shouldPass := map[string]PaginationToken{
"2": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 2,
},
"s4": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 4,
},
"s3_1": PaginationToken{
Type: PaginationTokenType("s"),
PDUPosition: 3,
EDUTypingPosition: 1,
},
"t3_1_4": PaginationToken{
Type: PaginationTokenType("t"),
PDUPosition: 3,
EDUTypingPosition: 1,
},
}
shouldFail := []string{
"",
"s_1",
"s_",
"a3_4",
"b",
"b-1",
"-4",
}
for test, expected := range shouldPass {
result, err := NewPaginationTokenFromString(test)
if err != nil {
t.Error(err)
}
if *result != expected {
t.Errorf("expected %v but got %v", expected.String(), result.String())
}
}
for _, test := range shouldFail {
if _, err := NewPaginationTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test)
}
}
}