Split up registration flow checking into smaller methods

* Fix broken link
* Added tests for flow checking

Signed-off-by: Andrew Morgan (https://amorgan.xyz) <andrew@amorgan.xyz>
This commit is contained in:
Andrew Morgan (https://amorgan.xyz) 2017-11-26 16:15:05 -08:00
parent 9066880402
commit d2c4ad43a1
No known key found for this signature in database
GPG key ID: 174BEAB009FD176D
3 changed files with 119 additions and 24 deletions

View file

@ -15,7 +15,7 @@
package authtypes
// Flow represents one possible way that the client can authenticate a request.
// http://matrix.org/docs/spec/HEAD/client_server/r0.3.0.html#user-interactive-authentication-api
// https://matrix.org/docs/spec/client_server/r0.3.0.html#user-interactive-authentication-api
type Flow struct {
Stages []LoginType `json:"stages"`
}

View file

@ -251,21 +251,19 @@ func handleRegistrationFlow(
}
}
// Check if a registration flow has been completed successfully
for _, flow := range cfg.Derived.Registration.Flows {
if checkFlowsEqual(flow, authtypes.Flow{sessions[sessionID]}) {
return completeRegistration(req.Context(), accountDB, deviceDB,
r.Username, r.Password, r.InitialDisplayName)
// Check if the user's registration flow has been completed successfully
if !checkFlowCompleted(authtypes.Flow{sessions[sessionID]}, cfg.Derived.Registration.Flows) {
// There are still more stages to complete.
// Return the flows and those that have been completed.
return util.JSONResponse{
Code: 401,
JSON: newUserInteractiveResponse(sessionID,
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
}
}
// There are still more stages to complete.
// Return the flows and those that have been completed.
return util.JSONResponse{
Code: 401,
JSON: newUserInteractiveResponse(sessionID,
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
}
return completeRegistration(req.Context(), accountDB, deviceDB,
r.Username, r.Password, r.InitialDisplayName)
}
// LegacyRegister process register requests from the legacy v1 API
@ -420,24 +418,22 @@ func isValidMacLogin(
return hmac.Equal(givenMac, expectedMAC), nil
}
// checkFlowsEqual checks if two registration flows have the same stages
// within them. Order of stages does not matter.
func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
a := aFlow.Stages
b := bFlow.Stages
if len(a) != len(b) {
return false
}
// checkFlows checks a single flow a against another, b. If a contains at least
// all of the stages that b does, checkFlows returns true.
func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool {
// Sort the slices for simple comparison
sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
// Account for any extra stages a user may do unnecessarily
extraStages := len(b) - len(a)
// Account for any extra stages a user may unnecessarily do.
extraStages := len(a) - len(b)
for i := range b {
if extraStages < 0 {
// The provided flow has run out of possible extraneous stages.
return false
}
if a[i] != b[i] {
// Wasn't a match, drop an extraneous stage.
extraStages--
continue
}
@ -445,11 +441,24 @@ func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
return true
}
// checkFlowCompleted checks if a registration flow complies with any flow
// dictated by the server. Order of stages does not matter. A user may complete
// extra stages as long as the required stages of at least one flow is met.
func checkFlowCompleted(flow authtypes.Flow, derivedFlows []authtypes.Flow) bool {
// Iterate through possible flows to check whether any have been fully completed.
for _, derivedFlow := range derivedFlows {
if checkFlows(flow.Stages, derivedFlow.Stages) {
return true
}
}
return false
}
type availableResponse struct {
Available bool `json:"available"`
}
// RegisterAvailable checks if the username is already taken or invalid
// RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable(
req *http.Request,
accountDB *accounts.Database,

View file

@ -0,0 +1,86 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
func TestFlowChecking(t *testing.T) {
derivedFlows := []authtypes.Flow{
{
[]authtypes.LoginType{
authtypes.LoginType("type1"),
authtypes.LoginType("type2"),
},
},
{
[]authtypes.LoginType{
authtypes.LoginType("type1"),
authtypes.LoginType("type3"),
},
},
}
testFlow1 := authtypes.Flow{
[]authtypes.LoginType{
authtypes.LoginType("type1"),
authtypes.LoginType("type3"),
},
}
testFlow2 := authtypes.Flow{
[]authtypes.LoginType{
authtypes.LoginType("type2"),
authtypes.LoginType("type3"),
},
}
testFlow3 := authtypes.Flow{
[]authtypes.LoginType{
authtypes.LoginType("type1"),
authtypes.LoginType("type3"),
authtypes.LoginType("type4"),
},
}
testFlow4 := authtypes.Flow{
[]authtypes.LoginType{},
}
testFlow5 := authtypes.Flow{
[]authtypes.LoginType{
authtypes.LoginType("type3"),
authtypes.LoginType("type2"),
authtypes.LoginType("type1"),
},
}
if !checkFlowCompleted(testFlow1, derivedFlows) {
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow1, ", from derived flows: ", derivedFlows, ". Should be true."))
}
if checkFlowCompleted(testFlow2, derivedFlows) {
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow2, ", from derived flows: ", derivedFlows, ". Should be false."))
}
if !checkFlowCompleted(testFlow3, derivedFlows) {
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow3, ", from derived flows: ", derivedFlows, ". Should be true."))
}
if checkFlowCompleted(testFlow4, derivedFlows) {
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow4, ", from derived flows: ", derivedFlows, ". Should be false."))
}
if !checkFlowCompleted(testFlow5, derivedFlows) {
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow5, ", from derived flows: ", derivedFlows, ". Should be true."))
}
}