diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go index 60ce17710..0f9f29ea5 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/flow.go @@ -15,7 +15,7 @@ package authtypes // Flow represents one possible way that the client can authenticate a request. -// http://matrix.org/docs/spec/HEAD/client_server/r0.3.0.html#user-interactive-authentication-api +// https://matrix.org/docs/spec/client_server/r0.3.0.html#user-interactive-authentication-api type Flow struct { Stages []LoginType `json:"stages"` } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 225fc594a..070317215 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -251,21 +251,19 @@ func handleRegistrationFlow( } } - // Check if a registration flow has been completed successfully - for _, flow := range cfg.Derived.Registration.Flows { - if checkFlowsEqual(flow, authtypes.Flow{sessions[sessionID]}) { - return completeRegistration(req.Context(), accountDB, deviceDB, - r.Username, r.Password, r.InitialDisplayName) + // Check if the user's registration flow has been completed successfully + if !checkFlowCompleted(authtypes.Flow{sessions[sessionID]}, cfg.Derived.Registration.Flows) { + // There are still more stages to complete. + // Return the flows and those that have been completed. + return util.JSONResponse{ + Code: 401, + JSON: newUserInteractiveResponse(sessionID, + cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params), } } - // There are still more stages to complete. - // Return the flows and those that have been completed. - return util.JSONResponse{ - Code: 401, - JSON: newUserInteractiveResponse(sessionID, - cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params), - } + return completeRegistration(req.Context(), accountDB, deviceDB, + r.Username, r.Password, r.InitialDisplayName) } // LegacyRegister process register requests from the legacy v1 API @@ -420,24 +418,22 @@ func isValidMacLogin( return hmac.Equal(givenMac, expectedMAC), nil } -// checkFlowsEqual checks if two registration flows have the same stages -// within them. Order of stages does not matter. -func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool { - a := aFlow.Stages - b := bFlow.Stages - if len(a) != len(b) { - return false - } +// checkFlows checks a single flow a against another, b. If a contains at least +// all of the stages that b does, checkFlows returns true. +func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool { + // Sort the slices for simple comparison sort.Slice(a, func(i, j int) bool { return a[i] < a[j] }) sort.Slice(b, func(i, j int) bool { return b[i] < b[j] }) - // Account for any extra stages a user may do unnecessarily - extraStages := len(b) - len(a) + // Account for any extra stages a user may unnecessarily do. + extraStages := len(a) - len(b) for i := range b { if extraStages < 0 { + // The provided flow has run out of possible extraneous stages. return false } if a[i] != b[i] { + // Wasn't a match, drop an extraneous stage. extraStages-- continue } @@ -445,11 +441,24 @@ func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool { return true } +// checkFlowCompleted checks if a registration flow complies with any flow +// dictated by the server. Order of stages does not matter. A user may complete +// extra stages as long as the required stages of at least one flow is met. +func checkFlowCompleted(flow authtypes.Flow, derivedFlows []authtypes.Flow) bool { + // Iterate through possible flows to check whether any have been fully completed. + for _, derivedFlow := range derivedFlows { + if checkFlows(flow.Stages, derivedFlow.Stages) { + return true + } + } + return false +} + type availableResponse struct { Available bool `json:"available"` } -// RegisterAvailable checks if the username is already taken or invalid +// RegisterAvailable checks if the username is already taken or invalid. func RegisterAvailable( req *http.Request, accountDB *accounts.Database, diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go new file mode 100644 index 000000000..86dcffd36 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register_test.go @@ -0,0 +1,86 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +func TestFlowChecking(t *testing.T) { + derivedFlows := []authtypes.Flow{ + { + []authtypes.LoginType{ + authtypes.LoginType("type1"), + authtypes.LoginType("type2"), + }, + }, + { + []authtypes.LoginType{ + authtypes.LoginType("type1"), + authtypes.LoginType("type3"), + }, + }, + } + + testFlow1 := authtypes.Flow{ + []authtypes.LoginType{ + authtypes.LoginType("type1"), + authtypes.LoginType("type3"), + }, + } + testFlow2 := authtypes.Flow{ + []authtypes.LoginType{ + authtypes.LoginType("type2"), + authtypes.LoginType("type3"), + }, + } + testFlow3 := authtypes.Flow{ + []authtypes.LoginType{ + authtypes.LoginType("type1"), + authtypes.LoginType("type3"), + authtypes.LoginType("type4"), + }, + } + testFlow4 := authtypes.Flow{ + []authtypes.LoginType{}, + } + testFlow5 := authtypes.Flow{ + []authtypes.LoginType{ + authtypes.LoginType("type3"), + authtypes.LoginType("type2"), + authtypes.LoginType("type1"), + }, + } + + if !checkFlowCompleted(testFlow1, derivedFlows) { + t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow1, ", from derived flows: ", derivedFlows, ". Should be true.")) + } + if checkFlowCompleted(testFlow2, derivedFlows) { + t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow2, ", from derived flows: ", derivedFlows, ". Should be false.")) + } + if !checkFlowCompleted(testFlow3, derivedFlows) { + t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow3, ", from derived flows: ", derivedFlows, ". Should be true.")) + } + if checkFlowCompleted(testFlow4, derivedFlows) { + t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow4, ", from derived flows: ", derivedFlows, ". Should be false.")) + } + if !checkFlowCompleted(testFlow5, derivedFlows) { + t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow5, ", from derived flows: ", derivedFlows, ". Should be true.")) + } +}