mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-31 18:53:10 -06:00
Update to use time.AfterFunc, add more tests
This commit is contained in:
parent
3c7c4de00d
commit
87d2d29fbe
|
|
@ -78,8 +78,8 @@ type sessionsDict struct {
|
||||||
timer map[string]*time.Timer
|
timer map[string]*time.Timer
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultTimeout is the timeout used to
|
// defaultTimeout is the timeout used to clean up sessions
|
||||||
const defaultTimeOut = time.Minute * 10
|
const defaultTimeOut = time.Minute * 5
|
||||||
|
|
||||||
// getCompletedStages returns the completed stages for a session.
|
// getCompletedStages returns the completed stages for a session.
|
||||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||||
|
|
@ -95,9 +95,9 @@ func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginTyp
|
||||||
|
|
||||||
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
|
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
|
||||||
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
|
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
|
||||||
d.Lock()
|
|
||||||
d.Unlock()
|
|
||||||
d.startTimer(defaultTimeOut, sessionID)
|
d.startTimer(defaultTimeOut, sessionID)
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
d.params[sessionID] = r
|
d.params[sessionID] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,6 +117,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
||||||
delete(d.sessions, sessionID)
|
delete(d.sessions, sessionID)
|
||||||
// stop the timer, e.g. because the registration was completed
|
// stop the timer, e.g. because the registration was completed
|
||||||
if t, ok := d.timer[sessionID]; ok {
|
if t, ok := d.timer[sessionID]; ok {
|
||||||
|
// trying to drain the channel results in a deadlock?
|
||||||
t.Stop()
|
t.Stop()
|
||||||
delete(d.timer, sessionID)
|
delete(d.timer, sessionID)
|
||||||
}
|
}
|
||||||
|
|
@ -131,20 +132,19 @@ func newSessionsDict() *sessionsDict {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
|
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
|
||||||
d.RLock()
|
d.Lock()
|
||||||
defer d.RUnlock()
|
defer d.Unlock()
|
||||||
if _, ok := d.timer[sessionID]; !ok {
|
t, ok := d.timer[sessionID]
|
||||||
go func() {
|
if ok {
|
||||||
timer := time.NewTimer(duration)
|
if !t.Stop() {
|
||||||
d.Lock()
|
<-t.C
|
||||||
d.timer[sessionID] = timer
|
}
|
||||||
d.Unlock()
|
t.Reset(duration)
|
||||||
select {
|
return
|
||||||
case <-timer.C:
|
|
||||||
d.deleteSession(sessionID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
d.timer[sessionID] = time.AfterFunc(duration, func() {
|
||||||
|
d.deleteSession(sessionID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// addCompletedSessionStage records that a session has completed an auth stage
|
// addCompletedSessionStage records that a session has completed an auth stage
|
||||||
|
|
|
||||||
|
|
@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
s := newSessionsDict()
|
s := newSessionsDict()
|
||||||
|
|
||||||
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
t.Run("session is cleaned up after a while", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
dummySession := "helloWorld"
|
dummySession := "helloWorld"
|
||||||
s.Lock()
|
|
||||||
// manually added, as s.addParams() would start the timer with the default timeout
|
// manually added, as s.addParams() would start the timer with the default timeout
|
||||||
s.params[dummySession] = registerRequest{Username: "Testing"}
|
s.params[dummySession] = registerRequest{Username: "Testing"}
|
||||||
s.Unlock()
|
|
||||||
s.startTimer(time.Millisecond, dummySession)
|
s.startTimer(time.Millisecond, dummySession)
|
||||||
time.Sleep(time.Millisecond * 10)
|
time.Sleep(time.Millisecond * 2)
|
||||||
if data, ok := s.getParams(dummySession); ok {
|
if data, ok := s.getParams(dummySession); ok {
|
||||||
t.Errorf("expected session to be deleted: %+v", data)
|
t.Errorf("expected session to be deleted: %+v", data)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
dummySession := "helloWorld2"
|
dummySession := "helloWorld2"
|
||||||
s.startTimer(time.Minute, dummySession)
|
s.startTimer(time.Minute, dummySession)
|
||||||
s.deleteSession(dummySession)
|
s.deleteSession(dummySession)
|
||||||
|
|
@ -235,4 +235,19 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("session timer is restarted after second call", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dummySession := "helloWorld3"
|
||||||
|
// the following will start a timer with the default timeout of 5min
|
||||||
|
s.addParams(dummySession, registerRequest{Username: "Testing"})
|
||||||
|
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
|
||||||
|
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
|
||||||
|
s.getCompletedStages(dummySession)
|
||||||
|
// reset the timer with a lower timeout
|
||||||
|
s.startTimer(time.Millisecond, dummySession)
|
||||||
|
time.Sleep(time.Millisecond * 2)
|
||||||
|
if data, ok := s.getParams(dummySession); ok {
|
||||||
|
t.Errorf("expected session to be deleted: %+v", data)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Loading…
Reference in a new issue