diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index d8be8b6cc..59a633419 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -425,18 +425,21 @@ func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool { sort.Slice(a, func(i, j int) bool { return a[i] < a[j] }) sort.Slice(b, func(i, j int) bool { return b[i] < b[j] }) - // Account for any extra stages a user may unnecessarily do. - extraStages := len(a) - len(b) - for i := range b { - if extraStages < 0 { - // The provided flow has run out of possible extraneous stages. + // Iterate through each slice, going to the next allowed slice only once + // we've found a match. + i, j := 0, 0 + for j < len(b) { + // Exit if we've reached the end of our input without being able to + // match all of the allowed stages. + if i >= len(a) { return false } - if a[i] != b[i] { - // Wasn't a match, drop an extraneous stage. - extraStages-- - continue + + // If we've found a stage we want, move on to the next allowed stage. + if a[i] == b[j] { + j++ } + i++ } return true } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go index 075d3d4b9..c6e692b07 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go @@ -109,3 +109,26 @@ func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) { t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.") } } + +// Should return false as we're providing less stages than are required. +func TestFlowCheckingShortIncorrectInput(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage8"), + } + if checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.") + } +} + +// Should return false as we're providing less stages than are required. +func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage8"), + authtypes.LoginType("stage9"), + authtypes.LoginType("stage10"), + authtypes.LoginType("stage11"), + } + if checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.") + } +}