Create a copy of slices before comparing

We do so to avoid sorting the original slices unexpectedly.

Signed-off-by: Andrew Morgan (https://amorgan.xyz) <andrew@amorgan.xyz>
This commit is contained in:
Andrew Morgan (https://amorgan.xyz) 2017-11-28 14:34:23 -08:00
parent 1537188123
commit 8bd5fd66fe
No known key found for this signature in database
GPG key ID: 174BEAB009FD176D
2 changed files with 25 additions and 15 deletions

View file

@ -418,25 +418,35 @@ func isValidMacLogin(
return hmac.Equal(givenMac, expectedMAC), nil
}
// checkFlows checks a single flow a against another, b. If a contains at least
// all of the stages that b does, checkFlows returns true.
func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool {
// Sort the slices for simple comparison
sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
// checkFlows checks a single completed flow against another required one. If
// one contains at least all of the stages that the other does, checkFlows
// returns true.
func checkFlows(
completedStages []authtypes.LoginType,
requiredStages []authtypes.LoginType,
) bool {
// Create temporary slices so they originals will not be modified on sorting
completed := make([]authtypes.LoginType, len(completedStages))
required := make([]authtypes.LoginType, len(requiredStages))
copy(completed, completedStages)
copy(required, requiredStages)
// Iterate through each slice, going to the next allowed slice only once
// Sort the slices for simple comparison
sort.Slice(completed, func(i, j int) bool { return completed[i] < completed[j] })
sort.Slice(required, func(i, j int) bool { return required[i] < required[j] })
// Iterate through each slice, going to the next required slice only once
// we've found a match.
i, j := 0, 0
for j < len(b) {
for j < len(required) {
// Exit if we've reached the end of our input without being able to
// match all of the allowed stages.
if i >= len(a) {
// match all of the required stages.
if i >= len(completed) {
return false
}
// If we've found a stage we want, move on to the next allowed stage.
if a[i] == b[j] {
// If we've found a stage we want, move on to the next required stage.
if completed[i] == required[j] {
j++
}
i++

View file

@ -97,7 +97,7 @@ func TestFlowCheckingInvalidStage(t *testing.T) {
// Should return true as we complete all stages of an allowed flow, though out
// of order, as well as extraneous stages.
func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) {
func TestFlowCheckingExtraneousUnordered(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage5"),
authtypes.LoginType("stage4"),
@ -110,7 +110,7 @@ func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) {
}
}
// Should return false as we're providing less stages than are required.
// Should return false as we're providing fewer stages than are required.
func TestFlowCheckingShortIncorrectInput(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage8"),
@ -120,7 +120,7 @@ func TestFlowCheckingShortIncorrectInput(t *testing.T) {
}
}
// Should return false as we're providing less stages than are required.
// Should return false as we're providing different stages than are required.
func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage8"),