diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 09c8c4e1d..8d3ca2247 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -50,26 +50,36 @@ const ( sessionIDLength = 24 ) +// sessionsDict represents every sessions' completed flow stages. type sessionsDict struct { sessions map[string][]authtypes.LoginType } -func (d sessionsDict) Get(key string) []authtypes.LoginType { - if v, ok := d.sessions[key]; ok { - return v +// GetCompletedStages returns the completed stages for a session. +func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { + if completedStages, ok := d.sessions[sessionID]; ok { + return completedStages } + // Ensure that a empty slice is return and not nil. See gh #399. return make([]authtypes.LoginType, 0) } -func (d *sessionsDict) AddCompletedStage(key string, v authtypes.LoginType) { - d.sessions[key] = append(d.Get(key), v) +// AddCompletedStage adds a completed stage to the session. +func (d *sessionsDict) AddCompletedStage(sessionID string, stage authtypes.LoginType) { + d.sessions[sessionID] = append(d.GetCompletedStages(sessionID), stage) +} + +// newSessionsDict returns a sessionsDict whose contained map is initialized and empty. +func newSessionsDict() *sessionsDict { + return &sessionsDict{ + sessions: make(map[string][]authtypes.LoginType), + } } var ( // TODO: Remove old sessions. Need to do so on a session-specific timeout. - sessions = sessionsDict{ // Sessions and completed flow stages - sessions: make(map[string][]authtypes.LoginType), - } + // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. + sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`) ) @@ -129,7 +139,7 @@ func newUserInteractiveResponse( params map[string]interface{}, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.Get(sessionID), params, sessionID, + fs, sessions.GetCompletedStages(sessionID), params, sessionID, } } @@ -495,7 +505,8 @@ func handleRegistrationFlow( // Check if the user's registration flow has been completed successfully // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet - return checkAndCompleteFlow(sessions.Get(sessionID), req, r, sessionID, cfg, accountDB, deviceDB) + return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), + req, r, sessionID, cfg, accountDB, deviceDB) } // checkAndCompleteFlow checks if a given registration flow is completed given