mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-19 12:53:09 -06:00
Fix NewPaginationTokenFromString, define unit test for it
This commit is contained in:
parent
fe707a163e
commit
fb9fedcc7f
|
|
@ -16,6 +16,7 @@ package types
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -29,6 +30,9 @@ var (
|
|||
// new instance of PaginationToken with an invalid type (i.e. neither "s"
|
||||
// nor "t").
|
||||
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
|
||||
// ErrInvalidPaginationTokenLen is returned when the pagination token is an
|
||||
// invalid length
|
||||
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
|
||||
)
|
||||
|
||||
// StreamPosition represents the offset in the sync stream a client is at.
|
||||
|
|
@ -71,27 +75,28 @@ type PaginationToken struct {
|
|||
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
|
||||
// case).
|
||||
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
|
||||
if len(s) == 0 {
|
||||
return nil, ErrInvalidPaginationTokenLen
|
||||
}
|
||||
|
||||
token = new(PaginationToken)
|
||||
var positions []string
|
||||
|
||||
// Check if the type is among the known ones.
|
||||
token.Type = PaginationTokenType(s[:1])
|
||||
if token.Type != PaginationTokenTypeStream && token.Type != PaginationTokenTypeTopology {
|
||||
if pduPos, perr := strconv.ParseInt(s, 10, 64); perr != nil {
|
||||
return nil, ErrInvalidPaginationTokenType
|
||||
} else {
|
||||
switch t := PaginationTokenType(s[:1]); t {
|
||||
case PaginationTokenTypeStream, PaginationTokenTypeTopology:
|
||||
token.Type = t
|
||||
positions = strings.Split(s[1:], "_")
|
||||
default:
|
||||
token.Type = PaginationTokenTypeStream
|
||||
token.PDUPosition = StreamPosition(pduPos)
|
||||
return
|
||||
positions = strings.Split(s, "_")
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the token (aka position).
|
||||
positions := strings.Split(s[1:], "_")
|
||||
|
||||
// Try to get the PDU position.
|
||||
if len(positions) >= 1 {
|
||||
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
|
||||
return nil, err
|
||||
} else if pduPos < 0 {
|
||||
return nil, errors.New("negative PDU position not allowed")
|
||||
} else {
|
||||
token.PDUPosition = StreamPosition(pduPos)
|
||||
}
|
||||
|
|
@ -101,6 +106,8 @@ func NewPaginationTokenFromString(s string) (token *PaginationToken, err error)
|
|||
if len(positions) >= 2 {
|
||||
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
|
||||
return nil, err
|
||||
} else if typPos < 0 {
|
||||
return nil, errors.New("negative EDU typing position not allowed")
|
||||
} else {
|
||||
token.EDUTypingPosition = StreamPosition(typPos)
|
||||
}
|
||||
|
|
|
|||
52
syncapi/types/types_test.go
Normal file
52
syncapi/types/types_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewPaginationTokenFromString(t *testing.T) {
|
||||
shouldPass := map[string]PaginationToken{
|
||||
"2": PaginationToken{
|
||||
Type: PaginationTokenType("s"),
|
||||
PDUPosition: 2,
|
||||
},
|
||||
"s4": PaginationToken{
|
||||
Type: PaginationTokenType("s"),
|
||||
PDUPosition: 4,
|
||||
},
|
||||
"s3_1": PaginationToken{
|
||||
Type: PaginationTokenType("s"),
|
||||
PDUPosition: 3,
|
||||
EDUTypingPosition: 1,
|
||||
},
|
||||
"t3_1_4": PaginationToken{
|
||||
Type: PaginationTokenType("t"),
|
||||
PDUPosition: 3,
|
||||
EDUTypingPosition: 1,
|
||||
},
|
||||
}
|
||||
|
||||
shouldFail := []string{
|
||||
"",
|
||||
"s_1",
|
||||
"s_",
|
||||
"a3_4",
|
||||
"b",
|
||||
"b-1",
|
||||
"-4",
|
||||
}
|
||||
|
||||
for test, expected := range shouldPass {
|
||||
result, err := NewPaginationTokenFromString(test)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if *result != expected {
|
||||
t.Errorf("expected %v but got %v", expected.String(), result.String())
|
||||
}
|
||||
}
|
||||
|
||||
for _, test := range shouldFail {
|
||||
if _, err := NewPaginationTokenFromString(test); err == nil {
|
||||
t.Errorf("input '%v' should have errored but didn't", test)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue