diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go index 0f9f29ea5..d5766fcc2 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright Andrew Morgan // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 070317215..d8be8b6cc 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -252,7 +252,7 @@ func handleRegistrationFlow( } // Check if the user's registration flow has been completed successfully - if !checkFlowCompleted(authtypes.Flow{sessions[sessionID]}, cfg.Derived.Registration.Flows) { + if !checkFlowCompleted(sessions[sessionID], cfg.Derived.Registration.Flows) { // There are still more stages to complete. // Return the flows and those that have been completed. return util.JSONResponse{ @@ -441,13 +441,13 @@ func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool { return true } -// checkFlowCompleted checks if a registration flow complies with any flow +// checkFlowCompleted checks if a registration flow complies with any allowed flow // dictated by the server. Order of stages does not matter. A user may complete // extra stages as long as the required stages of at least one flow is met. -func checkFlowCompleted(flow authtypes.Flow, derivedFlows []authtypes.Flow) bool { +func checkFlowCompleted(flow []authtypes.LoginType, allowedFlows []authtypes.Flow) bool { // Iterate through possible flows to check whether any have been fully completed. - for _, derivedFlow := range derivedFlows { - if checkFlows(flow.Stages, derivedFlow.Stages) { + for _, allowedFlow := range allowedFlows { + if checkFlows(flow, allowedFlow.Stages) { return true } } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go index 86dcffd36..075d3d4b9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go @@ -1,5 +1,4 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2017 New Vector Ltd +// Copyright 2017 Andrew Morgan // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,71 +15,97 @@ package routing import ( - "fmt" "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) -func TestFlowChecking(t *testing.T) { - derivedFlows := []authtypes.Flow{ +var ( + // Registration Flows that the server allows. + allowedFlows []authtypes.Flow = []authtypes.Flow{ { []authtypes.LoginType{ - authtypes.LoginType("type1"), - authtypes.LoginType("type2"), + authtypes.LoginType("stage1"), + authtypes.LoginType("stage2"), }, }, { []authtypes.LoginType{ - authtypes.LoginType("type1"), - authtypes.LoginType("type3"), + authtypes.LoginType("stage1"), + authtypes.LoginType("stage3"), }, }, } +) - testFlow1 := authtypes.Flow{ - []authtypes.LoginType{ - authtypes.LoginType("type1"), - authtypes.LoginType("type3"), - }, - } - testFlow2 := authtypes.Flow{ - []authtypes.LoginType{ - authtypes.LoginType("type2"), - authtypes.LoginType("type3"), - }, - } - testFlow3 := authtypes.Flow{ - []authtypes.LoginType{ - authtypes.LoginType("type1"), - authtypes.LoginType("type3"), - authtypes.LoginType("type4"), - }, - } - testFlow4 := authtypes.Flow{ - []authtypes.LoginType{}, - } - testFlow5 := authtypes.Flow{ - []authtypes.LoginType{ - authtypes.LoginType("type3"), - authtypes.LoginType("type2"), - authtypes.LoginType("type1"), - }, +// Should return true as we're completing all the stages of a single flow in +// order. +func TestFlowCheckingCompleteFlowOrdered(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage1"), + authtypes.LoginType("stage3"), } - if !checkFlowCompleted(testFlow1, derivedFlows) { - t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow1, ", from derived flows: ", derivedFlows, ". Should be true.")) - } - if checkFlowCompleted(testFlow2, derivedFlows) { - t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow2, ", from derived flows: ", derivedFlows, ". Should be false.")) - } - if !checkFlowCompleted(testFlow3, derivedFlows) { - t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow3, ", from derived flows: ", derivedFlows, ". Should be true.")) - } - if checkFlowCompleted(testFlow4, derivedFlows) { - t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow4, ", from derived flows: ", derivedFlows, ". Should be false.")) - } - if !checkFlowCompleted(testFlow5, derivedFlows) { - t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow5, ", from derived flows: ", derivedFlows, ". Should be true.")) + if !checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.") + } +} + +// Should return false as all stages in a single flow need to be completed. +func TestFlowCheckingStagesFromDifferentFlows(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage2"), + authtypes.LoginType("stage3"), + } + + if checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.") + } +} + +// Should return true as we're completing all the stages from a single flow, as +// well as some extraneous stages. +func TestFlowCheckingCompleteOrderedExtraneous(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage1"), + authtypes.LoginType("stage3"), + authtypes.LoginType("stage4"), + authtypes.LoginType("stage5"), + } + if !checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.") + } +} + +// Should return false as we're submitting an empty flow. +func TestFlowCheckingEmptyFlow(t *testing.T) { + testFlow := []authtypes.LoginType{} + if checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.") + } +} + +// Should return false as we've completed a stage that isn't in any allowed flow. +func TestFlowCheckingInvalidStage(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage8"), + } + if checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.") + } +} + +// Should return true as we complete all stages of an allowed flow, though out +// of order, as well as extraneous stages. +func TestFlowCheckingUnorderedAndExtraneous(t *testing.T) { + testFlow := []authtypes.LoginType{ + authtypes.LoginType("stage5"), + authtypes.LoginType("stage4"), + authtypes.LoginType("stage3"), + authtypes.LoginType("stage2"), + authtypes.LoginType("stage1"), + } + if !checkFlowCompleted(testFlow, allowedFlows) { + t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.") } }