Remember parameters on registration (#2225)
* Remember parameters for sessions Cleanup sessions on successfully registering or after a while * Add flakey test * Update to use time.AfterFunc, add more tests * Try to drain the channel, if possible
This commit is contained in:
parent
4c07374c42
commit
cf27e26712
|
@ -162,7 +162,7 @@ func AuthFallback(
|
|||
}
|
||||
|
||||
// Success. Add recaptcha as a completed login flow
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
|
||||
serveSuccess()
|
||||
return nil
|
||||
|
|
|
@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
|
|||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
uploadReq.UserID = device.UserID
|
||||
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
|
||||
|
|
|
@ -74,7 +74,7 @@ func Password(
|
|||
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
// Check the new password strength.
|
||||
if resErr = validatePassword(r.NewPassword); resErr != nil {
|
||||
|
|
|
@ -72,14 +72,19 @@ func init() {
|
|||
// sessionsDict keeps track of completed auth stages for each session.
|
||||
// It shouldn't be passed by value because it contains a mutex.
|
||||
type sessionsDict struct {
|
||||
sync.Mutex
|
||||
sync.RWMutex
|
||||
sessions map[string][]authtypes.LoginType
|
||||
params map[string]registerRequest
|
||||
timer map[string]*time.Timer
|
||||
}
|
||||
|
||||
// GetCompletedStages returns the completed stages for a session.
|
||||
func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
// defaultTimeout is the timeout used to clean up sessions
|
||||
const defaultTimeOut = time.Minute * 5
|
||||
|
||||
// getCompletedStages returns the completed stages for a session.
|
||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if completedStages, ok := d.sessions[sessionID]; ok {
|
||||
return completedStages
|
||||
|
@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
|
|||
return make([]authtypes.LoginType, 0)
|
||||
}
|
||||
|
||||
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
|
||||
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.params[sessionID] = r
|
||||
}
|
||||
|
||||
func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
r, ok := d.params[sessionID]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
// deleteSession cleans up a given session, either because the registration completed
|
||||
// successfully, or because a given timeout (default: 5min) was reached.
|
||||
func (d *sessionsDict) deleteSession(sessionID string) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
delete(d.params, sessionID)
|
||||
delete(d.sessions, sessionID)
|
||||
// stop the timer, e.g. because the registration was completed
|
||||
if t, ok := d.timer[sessionID]; ok {
|
||||
if !t.Stop() {
|
||||
select {
|
||||
case <-t.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
delete(d.timer, sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func newSessionsDict() *sessionsDict {
|
||||
return &sessionsDict{
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCompletedSessionStage records that a session has completed an auth stage.
|
||||
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
|
||||
sessions.Lock()
|
||||
defer sessions.Unlock()
|
||||
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
t, ok := d.timer[sessionID]
|
||||
if ok {
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
}
|
||||
t.Reset(duration)
|
||||
return
|
||||
}
|
||||
d.timer[sessionID] = time.AfterFunc(duration, func() {
|
||||
d.deleteSession(sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
for _, completedStage := range sessions.sessions[sessionID] {
|
||||
// addCompletedSessionStage records that a session has completed an auth stage
|
||||
// also starts a timer to delete the session once done.
|
||||
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
|
||||
d.startTimer(defaultTimeOut, sessionID)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
for _, completedStage := range d.sessions[sessionID] {
|
||||
if completedStage == stage {
|
||||
return
|
||||
}
|
||||
}
|
||||
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||
d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
|
||||
}
|
||||
|
||||
var (
|
||||
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
|
||||
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
|
||||
sessions = newSessionsDict()
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
)
|
||||
|
@ -167,7 +223,7 @@ func newUserInteractiveResponse(
|
|||
params map[string]interface{},
|
||||
) userInteractiveResponse {
|
||||
return userInteractiveResponse{
|
||||
fs, sessions.GetCompletedStages(sessionID), params, sessionID,
|
||||
fs, sessions.getCompletedStages(sessionID), params, sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -645,12 +701,12 @@ func handleRegistrationFlow(
|
|||
}
|
||||
|
||||
// Add Recaptcha to the list of completed registration stages
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||
|
||||
case authtypes.LoginTypeDummy:
|
||||
// there is nothing to do
|
||||
// Add Dummy to the list of completed registration stages
|
||||
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
|
||||
|
||||
case "":
|
||||
// An empty auth type means that we want to fetch the available
|
||||
|
@ -666,7 +722,7 @@ func handleRegistrationFlow(
|
|||
// Check if the user's registration flow has been completed successfully
|
||||
// A response with current registration flow and remaining available methods
|
||||
// will be returned if a flow has not been successfully completed yet
|
||||
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID),
|
||||
return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
|
||||
req, r, sessionID, cfg, userAPI)
|
||||
}
|
||||
|
||||
|
@ -708,7 +764,7 @@ func handleApplicationServiceRegistration(
|
|||
// Don't need to worry about appending to registration stages as
|
||||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
|
||||
)
|
||||
}
|
||||
|
@ -727,11 +783,11 @@ func checkAndCompleteFlow(
|
|||
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
|
||||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
|
||||
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
|
||||
)
|
||||
}
|
||||
|
||||
sessions.addParams(sessionID, r)
|
||||
// There are still more stages to complete.
|
||||
// Return the flows and those that have been completed.
|
||||
return util.JSONResponse{
|
||||
|
@ -750,11 +806,25 @@ func checkAndCompleteFlow(
|
|||
func completeRegistration(
|
||||
ctx context.Context,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
username, password, appserviceID, ipAddr, userAgent string,
|
||||
username, password, appserviceID, ipAddr, userAgent, sessionID string,
|
||||
inhibitLogin eventutil.WeakBoolean,
|
||||
displayName, deviceID *string,
|
||||
accType userapi.AccountType,
|
||||
) util.JSONResponse {
|
||||
var registrationOK bool
|
||||
defer func() {
|
||||
if registrationOK {
|
||||
sessions.deleteSession(sessionID)
|
||||
}
|
||||
}()
|
||||
|
||||
if data, ok := sessions.getParams(sessionID); ok {
|
||||
username = data.Username
|
||||
password = data.Password
|
||||
deviceID = data.DeviceID
|
||||
displayName = data.InitialDisplayName
|
||||
inhibitLogin = data.InhibitLogin
|
||||
}
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
@ -795,6 +865,7 @@ func completeRegistration(
|
|||
// Check whether inhibit_login option is set. If so, don't create an access
|
||||
// token or a device for this user
|
||||
if inhibitLogin {
|
||||
registrationOK = true
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
|
@ -828,6 +899,7 @@ func completeRegistration(
|
|||
}
|
||||
}
|
||||
|
||||
registrationOK = true
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
|
@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
|
|||
if ssrr.Admin {
|
||||
accType = userapi.AccountTypeAdmin
|
||||
}
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType)
|
||||
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package routing
|
|||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
|
|||
func TestEmptyCompletedFlows(t *testing.T) {
|
||||
fakeEmptySessions := newSessionsDict()
|
||||
fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
|
||||
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID)
|
||||
ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
|
||||
|
||||
// check for []
|
||||
if ret == nil || len(ret) != 0 {
|
||||
|
@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) {
|
|||
t.Errorf("user_id should not have been valid: @_something_else:localhost")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCleanUp(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
|
||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld"
|
||||
// manually added, as s.addParams() would start the timer with the default timeout
|
||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld2"
|
||||
s.startTimer(time.Minute, dummySession)
|
||||
s.deleteSession(dummySession)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dummySession := "helloWorld3"
|
||||
// the following will start a timer with the default timeout of 5min
|
||||
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
||||
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
||||
s.getCompletedStages(dummySession)
|
||||
// reset the timer with a lower timeout
|
||||
s.startTimer(time.Millisecond, dummySession)
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
if data, ok := s.getParams(dummySession); ok {
|
||||
t.Errorf("expected session to be deleted: %+v", data)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id
|
|||
|
||||
# Flakey
|
||||
Local device key changes appear in /keys/changes
|
||||
/context/ with lazy_load_members filter works
|
||||
|
||||
# we don't support groups
|
||||
Remove group category
|
||||
|
|
|
@ -601,3 +601,7 @@ Can query remote device keys using POST after notification
|
|||
Device deletion propagates over federation
|
||||
Get left notifs in sync and /keys/changes when other user leaves
|
||||
Remote banned user is kicked and may not rejoin until unbanned
|
||||
registration remembers parameters
|
||||
registration accepts non-ascii passwords
|
||||
registration with inhibit_login inhibits login
|
||||
|
||||
|
|
Loading…
Reference in a new issue