diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 59a633419..ef8fbaf54 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -418,25 +418,35 @@ func isValidMacLogin( return hmac.Equal(givenMac, expectedMAC), nil } -// checkFlows checks a single flow a against another, b. If a contains at least -// all of the stages that b does, checkFlows returns true. -func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool { - // Sort the slices for simple comparison - sort.Slice(a, func(i, j int) bool { return a[i] < a[j] }) - sort.Slice(b, func(i, j int) bool { return b[i] < b[j] }) +// checkFlows checks a single completed flow against another required one. If +// one contains at least all of the stages that the other does, checkFlows +// returns true. +func checkFlows( + completedStages []authtypes.LoginType, + requiredStages []authtypes.LoginType, +) bool { + // Create temporary slices so they originals will not be modified on sorting + completed := make([]authtypes.LoginType, len(completedStages)) + required := make([]authtypes.LoginType, len(requiredStages)) + copy(completed, completedStages) + copy(required, requiredStages) - // Iterate through each slice, going to the next allowed slice only once + // Sort the slices for simple comparison + sort.Slice(completed, func(i, j int) bool { return completed[i] < completed[j] }) + sort.Slice(required, func(i, j int) bool { return required[i] < required[j] }) + + // Iterate through each slice, going to the next required slice only once // we've found a match. i, j := 0, 0 - for j < len(b) { + for j < len(required) { // Exit if we've reached the end of our input without being able to - // match all of the allowed stages. - if i >= len(a) { + // match all of the required stages. + if i >= len(completed) { return false } - // If we've found a stage we want, move on to the next allowed stage. - if a[i] == b[j] { + // If we've found a stage we want, move on to the next required stage. + if completed[i] == required[j] { j++ } i++ diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go index c6e692b07..de18c8d2a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go @@ -97,7 +97,7 @@ func TestFlowCheckingInvalidStage(t *testing.T) { // Should return true as we complete all stages of an allowed flow, though out // of order, as well as extraneous stages. -func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) { +func TestFlowCheckingExtraneousUnordered(t *testing.T) { testFlow := []authtypes.LoginType{ authtypes.LoginType("stage5"), authtypes.LoginType("stage4"), @@ -110,7 +110,7 @@ func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) { } } -// Should return false as we're providing less stages than are required. +// Should return false as we're providing fewer stages than are required. func TestFlowCheckingShortIncorrectInput(t *testing.T) { testFlow := []authtypes.LoginType{ authtypes.LoginType("stage8"), @@ -120,7 +120,7 @@ func TestFlowCheckingShortIncorrectInput(t *testing.T) { } } -// Should return false as we're providing less stages than are required. +// Should return false as we're providing different stages than are required. func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { testFlow := []authtypes.LoginType{ authtypes.LoginType("stage8"),