mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 09:23:09 -06:00
Split up registration flow checking into smaller methods
* Fix broken link * Added tests for flow checking Signed-off-by: Andrew Morgan (https://amorgan.xyz) <andrew@amorgan.xyz>
This commit is contained in:
parent
9066880402
commit
d2c4ad43a1
|
|
@ -15,7 +15,7 @@
|
|||
package authtypes
|
||||
|
||||
// Flow represents one possible way that the client can authenticate a request.
|
||||
// http://matrix.org/docs/spec/HEAD/client_server/r0.3.0.html#user-interactive-authentication-api
|
||||
// https://matrix.org/docs/spec/client_server/r0.3.0.html#user-interactive-authentication-api
|
||||
type Flow struct {
|
||||
Stages []LoginType `json:"stages"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -251,21 +251,19 @@ func handleRegistrationFlow(
|
|||
}
|
||||
}
|
||||
|
||||
// Check if a registration flow has been completed successfully
|
||||
for _, flow := range cfg.Derived.Registration.Flows {
|
||||
if checkFlowsEqual(flow, authtypes.Flow{sessions[sessionID]}) {
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB,
|
||||
r.Username, r.Password, r.InitialDisplayName)
|
||||
// Check if the user's registration flow has been completed successfully
|
||||
if !checkFlowCompleted(authtypes.Flow{sessions[sessionID]}, cfg.Derived.Registration.Flows) {
|
||||
// There are still more stages to complete.
|
||||
// Return the flows and those that have been completed.
|
||||
return util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: newUserInteractiveResponse(sessionID,
|
||||
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
|
||||
}
|
||||
}
|
||||
|
||||
// There are still more stages to complete.
|
||||
// Return the flows and those that have been completed.
|
||||
return util.JSONResponse{
|
||||
Code: 401,
|
||||
JSON: newUserInteractiveResponse(sessionID,
|
||||
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
|
||||
}
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB,
|
||||
r.Username, r.Password, r.InitialDisplayName)
|
||||
}
|
||||
|
||||
// LegacyRegister process register requests from the legacy v1 API
|
||||
|
|
@ -420,24 +418,22 @@ func isValidMacLogin(
|
|||
return hmac.Equal(givenMac, expectedMAC), nil
|
||||
}
|
||||
|
||||
// checkFlowsEqual checks if two registration flows have the same stages
|
||||
// within them. Order of stages does not matter.
|
||||
func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
|
||||
a := aFlow.Stages
|
||||
b := bFlow.Stages
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
// checkFlows checks a single flow a against another, b. If a contains at least
|
||||
// all of the stages that b does, checkFlows returns true.
|
||||
func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool {
|
||||
// Sort the slices for simple comparison
|
||||
sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
|
||||
sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
|
||||
|
||||
// Account for any extra stages a user may do unnecessarily
|
||||
extraStages := len(b) - len(a)
|
||||
// Account for any extra stages a user may unnecessarily do.
|
||||
extraStages := len(a) - len(b)
|
||||
for i := range b {
|
||||
if extraStages < 0 {
|
||||
// The provided flow has run out of possible extraneous stages.
|
||||
return false
|
||||
}
|
||||
if a[i] != b[i] {
|
||||
// Wasn't a match, drop an extraneous stage.
|
||||
extraStages--
|
||||
continue
|
||||
}
|
||||
|
|
@ -445,11 +441,24 @@ func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// checkFlowCompleted checks if a registration flow complies with any flow
|
||||
// dictated by the server. Order of stages does not matter. A user may complete
|
||||
// extra stages as long as the required stages of at least one flow is met.
|
||||
func checkFlowCompleted(flow authtypes.Flow, derivedFlows []authtypes.Flow) bool {
|
||||
// Iterate through possible flows to check whether any have been fully completed.
|
||||
for _, derivedFlow := range derivedFlows {
|
||||
if checkFlows(flow.Stages, derivedFlow.Stages) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type availableResponse struct {
|
||||
Available bool `json:"available"`
|
||||
}
|
||||
|
||||
// RegisterAvailable checks if the username is already taken or invalid
|
||||
// RegisterAvailable checks if the username is already taken or invalid.
|
||||
func RegisterAvailable(
|
||||
req *http.Request,
|
||||
accountDB *accounts.Database,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
// Copyright 2017 New Vector Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
||||
func TestFlowChecking(t *testing.T) {
|
||||
derivedFlows := []authtypes.Flow{
|
||||
{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type1"),
|
||||
authtypes.LoginType("type2"),
|
||||
},
|
||||
},
|
||||
{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type1"),
|
||||
authtypes.LoginType("type3"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testFlow1 := authtypes.Flow{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type1"),
|
||||
authtypes.LoginType("type3"),
|
||||
},
|
||||
}
|
||||
testFlow2 := authtypes.Flow{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type2"),
|
||||
authtypes.LoginType("type3"),
|
||||
},
|
||||
}
|
||||
testFlow3 := authtypes.Flow{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type1"),
|
||||
authtypes.LoginType("type3"),
|
||||
authtypes.LoginType("type4"),
|
||||
},
|
||||
}
|
||||
testFlow4 := authtypes.Flow{
|
||||
[]authtypes.LoginType{},
|
||||
}
|
||||
testFlow5 := authtypes.Flow{
|
||||
[]authtypes.LoginType{
|
||||
authtypes.LoginType("type3"),
|
||||
authtypes.LoginType("type2"),
|
||||
authtypes.LoginType("type1"),
|
||||
},
|
||||
}
|
||||
|
||||
if !checkFlowCompleted(testFlow1, derivedFlows) {
|
||||
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow1, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||
}
|
||||
if checkFlowCompleted(testFlow2, derivedFlows) {
|
||||
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow2, ", from derived flows: ", derivedFlows, ". Should be false."))
|
||||
}
|
||||
if !checkFlowCompleted(testFlow3, derivedFlows) {
|
||||
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow3, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||
}
|
||||
if checkFlowCompleted(testFlow4, derivedFlows) {
|
||||
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow4, ", from derived flows: ", derivedFlows, ". Should be false."))
|
||||
}
|
||||
if !checkFlowCompleted(testFlow5, derivedFlows) {
|
||||
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow5, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue