mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-13 01:43:09 -06:00
Handle cases where expireTime is updated
This commit is contained in:
parent
59aa8683f1
commit
58996fb131
|
|
@ -17,13 +17,10 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
userExists = struct{}{} // Value denoting user is present in a userSet.
|
||||
defaultTypingTimeout = 10 * time.Second
|
||||
)
|
||||
var defaultTypingTimeout = 10 * time.Second
|
||||
|
||||
// userSet is a map of user IDs.
|
||||
type userSet map[string]struct{}
|
||||
// userSet is a map of user IDs to their time of expiry.
|
||||
type userSet map[string]time.Time
|
||||
|
||||
// TypingCache maintains a list of users typing in each room.
|
||||
type TypingCache struct {
|
||||
|
|
@ -57,33 +54,40 @@ func (t *TypingCache) GetTypingUsers(roomID string) (users []string) {
|
|||
func (t *TypingCache) AddTypingUser(userID, roomID string, expire *time.Time) {
|
||||
expireTime := getExpireTime(expire)
|
||||
if until := time.Until(expireTime); until > 0 {
|
||||
t.addUser(userID, roomID)
|
||||
t.removeUserAfterDuration(userID, roomID, until)
|
||||
t.addUser(userID, roomID, expireTime)
|
||||
t.removeUserAfterTime(userID, roomID, expireTime)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TypingCache) addUser(userID, roomID string) {
|
||||
// addUser with mutex lock.
|
||||
func (t *TypingCache) addUser(userID, roomID string, expireTime time.Time) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
if t.data[roomID] == nil {
|
||||
t.data[roomID] = make(userSet)
|
||||
}
|
||||
|
||||
t.data[roomID][userID] = userExists
|
||||
t.Unlock()
|
||||
t.data[roomID][userID] = expireTime
|
||||
}
|
||||
|
||||
// Creates a go routine which removes the user after d duration has elapsed.
|
||||
func (t *TypingCache) removeUserAfterDuration(userID, roomID string, d time.Duration) {
|
||||
// Creates a go routine which removes the user after expireTime has elapsed,
|
||||
// only if the expiration is not updated to a later time in cache.
|
||||
func (t *TypingCache) removeUserAfterTime(userID, roomID string, expireTime time.Time) {
|
||||
go func() {
|
||||
time.Sleep(d)
|
||||
t.removeUser(userID, roomID)
|
||||
time.Sleep(time.Until(expireTime))
|
||||
t.removeUserIfExpired(userID, roomID)
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *TypingCache) removeUser(userID, roomID string) {
|
||||
// removeUserIfExpired with mutex lock.
|
||||
func (t *TypingCache) removeUserIfExpired(userID, roomID string) {
|
||||
t.Lock()
|
||||
delete(t.data[roomID], userID)
|
||||
t.Unlock()
|
||||
defer t.Unlock()
|
||||
|
||||
if time.Until(t.data[roomID][userID]) <= 0 {
|
||||
delete(t.data[roomID], userID)
|
||||
}
|
||||
}
|
||||
|
||||
func getExpireTime(expire *time.Time) time.Time {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const defaultInterval = time.Second
|
||||
const longInterval = time.Hour
|
||||
|
||||
func TestTypingCache(t *testing.T) {
|
||||
tCache := NewTypingCache()
|
||||
|
|
@ -35,13 +35,13 @@ func TestTypingCache(t *testing.T) {
|
|||
testGetTypingUsers(t, tCache)
|
||||
})
|
||||
|
||||
t.Run("GetTypingUsersAfterTimeout", func(t *testing.T) {
|
||||
testGetTypingUsersAfterTimeout(t, tCache)
|
||||
t.Run("removeUserIfExpired", func(t *testing.T) {
|
||||
testRemoveUserIfExpired(t, tCache)
|
||||
})
|
||||
}
|
||||
|
||||
func testAddTypingUser(t *testing.T, tCache *TypingCache) {
|
||||
timeAfterDefaultInterval := time.Now().Add(defaultInterval)
|
||||
timeAfterLongInterval := time.Now().Add(longInterval)
|
||||
tests := []struct {
|
||||
userID string
|
||||
roomID string
|
||||
|
|
@ -51,9 +51,8 @@ func testAddTypingUser(t *testing.T, tCache *TypingCache) {
|
|||
{"user2", "room1", nil},
|
||||
{"user3", "room1", nil},
|
||||
{"user4", "room1", nil},
|
||||
// Override timeout
|
||||
{"user1", "room2", &timeAfterDefaultInterval},
|
||||
{"user1", "room2", nil},
|
||||
// removeUserIfExpired should not remove the user before expiration time.
|
||||
{"user1", "room2", &timeAfterLongInterval},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -80,16 +79,17 @@ func testGetTypingUsers(t *testing.T, tCache *TypingCache) {
|
|||
}
|
||||
}
|
||||
|
||||
func testGetTypingUsersAfterTimeout(t *testing.T, tCache *TypingCache) {
|
||||
time.Sleep(defaultInterval)
|
||||
func testRemoveUserIfExpired(t *testing.T, tCache *TypingCache) {
|
||||
tests := []struct {
|
||||
roomID string
|
||||
userID string
|
||||
wantUsers []string
|
||||
}{
|
||||
{"room2", []string{"user1"}},
|
||||
{"room2", "user1", []string{"user1"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tCache.removeUserIfExpired(tt.userID, tt.roomID)
|
||||
if gotUsers := tCache.GetTypingUsers(tt.roomID); !reflect.DeepEqual(gotUsers, tt.wantUsers) {
|
||||
t.Errorf("TypingCache.GetTypingUsers(%s) = %v, want %v", tt.roomID, gotUsers, tt.wantUsers)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue