mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 09:23:09 -06:00
Split up registration flow checking into smaller methods
* Fix broken link * Added tests for flow checking Signed-off-by: Andrew Morgan (https://amorgan.xyz) <andrew@amorgan.xyz>
This commit is contained in:
parent
9066880402
commit
d2c4ad43a1
|
|
@ -15,7 +15,7 @@
|
||||||
package authtypes
|
package authtypes
|
||||||
|
|
||||||
// Flow represents one possible way that the client can authenticate a request.
|
// Flow represents one possible way that the client can authenticate a request.
|
||||||
// http://matrix.org/docs/spec/HEAD/client_server/r0.3.0.html#user-interactive-authentication-api
|
// https://matrix.org/docs/spec/client_server/r0.3.0.html#user-interactive-authentication-api
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
Stages []LoginType `json:"stages"`
|
Stages []LoginType `json:"stages"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -251,21 +251,19 @@ func handleRegistrationFlow(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a registration flow has been completed successfully
|
// Check if the user's registration flow has been completed successfully
|
||||||
for _, flow := range cfg.Derived.Registration.Flows {
|
if !checkFlowCompleted(authtypes.Flow{sessions[sessionID]}, cfg.Derived.Registration.Flows) {
|
||||||
if checkFlowsEqual(flow, authtypes.Flow{sessions[sessionID]}) {
|
// There are still more stages to complete.
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB,
|
// Return the flows and those that have been completed.
|
||||||
r.Username, r.Password, r.InitialDisplayName)
|
return util.JSONResponse{
|
||||||
|
Code: 401,
|
||||||
|
JSON: newUserInteractiveResponse(sessionID,
|
||||||
|
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// There are still more stages to complete.
|
return completeRegistration(req.Context(), accountDB, deviceDB,
|
||||||
// Return the flows and those that have been completed.
|
r.Username, r.Password, r.InitialDisplayName)
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 401,
|
|
||||||
JSON: newUserInteractiveResponse(sessionID,
|
|
||||||
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LegacyRegister process register requests from the legacy v1 API
|
// LegacyRegister process register requests from the legacy v1 API
|
||||||
|
|
@ -420,24 +418,22 @@ func isValidMacLogin(
|
||||||
return hmac.Equal(givenMac, expectedMAC), nil
|
return hmac.Equal(givenMac, expectedMAC), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkFlowsEqual checks if two registration flows have the same stages
|
// checkFlows checks a single flow a against another, b. If a contains at least
|
||||||
// within them. Order of stages does not matter.
|
// all of the stages that b does, checkFlows returns true.
|
||||||
func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
|
func checkFlows(a []authtypes.LoginType, b []authtypes.LoginType) bool {
|
||||||
a := aFlow.Stages
|
// Sort the slices for simple comparison
|
||||||
b := bFlow.Stages
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
|
sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
|
||||||
sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
|
sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
|
||||||
|
|
||||||
// Account for any extra stages a user may do unnecessarily
|
// Account for any extra stages a user may unnecessarily do.
|
||||||
extraStages := len(b) - len(a)
|
extraStages := len(a) - len(b)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
if extraStages < 0 {
|
if extraStages < 0 {
|
||||||
|
// The provided flow has run out of possible extraneous stages.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a[i] != b[i] {
|
if a[i] != b[i] {
|
||||||
|
// Wasn't a match, drop an extraneous stage.
|
||||||
extraStages--
|
extraStages--
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -445,11 +441,24 @@ func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkFlowCompleted checks if a registration flow complies with any flow
|
||||||
|
// dictated by the server. Order of stages does not matter. A user may complete
|
||||||
|
// extra stages as long as the required stages of at least one flow is met.
|
||||||
|
func checkFlowCompleted(flow authtypes.Flow, derivedFlows []authtypes.Flow) bool {
|
||||||
|
// Iterate through possible flows to check whether any have been fully completed.
|
||||||
|
for _, derivedFlow := range derivedFlows {
|
||||||
|
if checkFlows(flow.Stages, derivedFlow.Stages) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type availableResponse struct {
|
type availableResponse struct {
|
||||||
Available bool `json:"available"`
|
Available bool `json:"available"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterAvailable checks if the username is already taken or invalid
|
// RegisterAvailable checks if the username is already taken or invalid.
|
||||||
func RegisterAvailable(
|
func RegisterAvailable(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB *accounts.Database,
|
accountDB *accounts.Database,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
// Copyright 2017 New Vector Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlowChecking(t *testing.T) {
|
||||||
|
derivedFlows := []authtypes.Flow{
|
||||||
|
{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type1"),
|
||||||
|
authtypes.LoginType("type2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type1"),
|
||||||
|
authtypes.LoginType("type3"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testFlow1 := authtypes.Flow{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type1"),
|
||||||
|
authtypes.LoginType("type3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testFlow2 := authtypes.Flow{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type2"),
|
||||||
|
authtypes.LoginType("type3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testFlow3 := authtypes.Flow{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type1"),
|
||||||
|
authtypes.LoginType("type3"),
|
||||||
|
authtypes.LoginType("type4"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testFlow4 := authtypes.Flow{
|
||||||
|
[]authtypes.LoginType{},
|
||||||
|
}
|
||||||
|
testFlow5 := authtypes.Flow{
|
||||||
|
[]authtypes.LoginType{
|
||||||
|
authtypes.LoginType("type3"),
|
||||||
|
authtypes.LoginType("type2"),
|
||||||
|
authtypes.LoginType("type1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !checkFlowCompleted(testFlow1, derivedFlows) {
|
||||||
|
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow1, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||||
|
}
|
||||||
|
if checkFlowCompleted(testFlow2, derivedFlows) {
|
||||||
|
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow2, ", from derived flows: ", derivedFlows, ". Should be false."))
|
||||||
|
}
|
||||||
|
if !checkFlowCompleted(testFlow3, derivedFlows) {
|
||||||
|
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow3, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||||
|
}
|
||||||
|
if checkFlowCompleted(testFlow4, derivedFlows) {
|
||||||
|
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow4, ", from derived flows: ", derivedFlows, ". Should be false."))
|
||||||
|
}
|
||||||
|
if !checkFlowCompleted(testFlow5, derivedFlows) {
|
||||||
|
t.Error(fmt.Sprint("Failed to verify registration flow: ", testFlow5, ", from derived flows: ", derivedFlows, ". Should be true."))
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue