Update to use time.AfterFunc, add more tests

This commit is contained in:
Till Faelligen 2022-02-24 12:50:49 +01:00
parent 3c7c4de00d
commit 87d2d29fbe
2 changed files with 35 additions and 20 deletions

View file

@ -78,8 +78,8 @@ type sessionsDict struct {
timer map[string]*time.Timer
}
// defaultTimeout is the timeout used to
const defaultTimeOut = time.Minute * 10
// defaultTimeout is the timeout used to clean up sessions
const defaultTimeOut = time.Minute * 5
// getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
@ -95,9 +95,9 @@ func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginTyp
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
d.Lock()
d.Unlock()
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.params[sessionID] = r
}
@ -117,6 +117,7 @@ func (d *sessionsDict) deleteSession(sessionID string) {
delete(d.sessions, sessionID)
// stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok {
// trying to drain the channel results in a deadlock?
t.Stop()
delete(d.timer, sessionID)
}
@ -131,20 +132,19 @@ func newSessionsDict() *sessionsDict {
}
func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
d.RLock()
defer d.RUnlock()
if _, ok := d.timer[sessionID]; !ok {
go func() {
timer := time.NewTimer(duration)
d.Lock()
d.timer[sessionID] = timer
d.Unlock()
select {
case <-timer.C:
d.deleteSession(sessionID)
}
}()
d.Lock()
defer d.Unlock()
t, ok := d.timer[sessionID]
if ok {
if !t.Stop() {
<-t.C
}
t.Reset(duration)
return
}
d.timer[sessionID] = time.AfterFunc(duration, func() {
d.deleteSession(sessionID)
})
}
// addCompletedSessionStage records that a session has completed an auth stage

View file

@ -214,19 +214,19 @@ func TestSessionCleanUp(t *testing.T) {
s := newSessionsDict()
t.Run("session is cleaned up after a while", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld"
s.Lock()
// manually added, as s.addParams() would start the timer with the default timeout
s.params[dummySession] = registerRequest{Username: "Testing"}
s.Unlock()
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 10)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld2"
s.startTimer(time.Minute, dummySession)
s.deleteSession(dummySession)
@ -235,4 +235,19 @@ func TestSessionCleanUp(t *testing.T) {
}
})
t.Run("session timer is restarted after second call", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld3"
// the following will start a timer with the default timeout of 5min
s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.getCompletedStages(dummySession)
// reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
}