Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/consent-tracking

This commit is contained in:
Till Faelligen 2022-02-25 15:21:06 +01:00
commit ed16a2f107
10 changed files with 172 additions and 38 deletions

View file

@ -162,7 +162,7 @@ func AuthFallback(
} }
// Success. Add recaptcha as a completed login flow // Success. Add recaptcha as a completed login flow
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess() serveSuccess()
return nil return nil

View file

@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)

View file

@ -74,7 +74,7 @@ func Password(
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil { if resErr = validatePassword(r.NewPassword); resErr != nil {

View file

@ -72,14 +72,19 @@ func init() {
// sessionsDict keeps track of completed auth stages for each session. // sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex. // It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct { type sessionsDict struct {
sync.Mutex sync.RWMutex
sessions map[string][]authtypes.LoginType sessions map[string][]authtypes.LoginType
params map[string]registerRequest
timer map[string]*time.Timer
} }
// GetCompletedStages returns the completed stages for a session. // defaultTimeout is the timeout used to clean up sessions
func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { const defaultTimeOut = time.Minute * 5
d.Lock()
defer d.Unlock() // getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
d.RLock()
defer d.RUnlock()
if completedStages, ok := d.sessions[sessionID]; ok { if completedStages, ok := d.sessions[sessionID]; ok {
return completedStages return completedStages
@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
return make([]authtypes.LoginType, 0) return make([]authtypes.LoginType, 0)
} }
func newSessionsDict() *sessionsDict { // addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
return &sessionsDict{ func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
sessions: make(map[string][]authtypes.LoginType), d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.params[sessionID] = r
}
func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
d.RLock()
defer d.RUnlock()
r, ok := d.params[sessionID]
return r, ok
}
// deleteSession cleans up a given session, either because the registration completed
// successfully, or because a given timeout (default: 5min) was reached.
func (d *sessionsDict) deleteSession(sessionID string) {
d.Lock()
defer d.Unlock()
delete(d.params, sessionID)
delete(d.sessions, sessionID)
// stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
delete(d.timer, sessionID)
} }
} }
// AddCompletedSessionStage records that a session has completed an auth stage. func newSessionsDict() *sessionsDict {
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { return &sessionsDict{
sessions.Lock() sessions: make(map[string][]authtypes.LoginType),
defer sessions.Unlock() params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
}
}
for _, completedStage := range sessions.sessions[sessionID] { func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
d.Lock()
defer d.Unlock()
t, ok := d.timer[sessionID]
if ok {
if !t.Stop() {
<-t.C
}
t.Reset(duration)
return
}
d.timer[sessionID] = time.AfterFunc(duration, func() {
d.deleteSession(sessionID)
})
}
// addCompletedSessionStage records that a session has completed an auth stage
// also starts a timer to delete the session once done.
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
for _, completedStage := range d.sessions[sessionID] {
if completedStage == stage { if completedStage == stage {
return return
} }
} }
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
} }
var ( var (
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
sessions = newSessionsDict() sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
@ -167,7 +223,7 @@ func newUserInteractiveResponse(
params map[string]interface{}, params map[string]interface{},
) userInteractiveResponse { ) userInteractiveResponse {
return userInteractiveResponse{ return userInteractiveResponse{
fs, sessions.GetCompletedStages(sessionID), params, sessionID, fs, sessions.getCompletedStages(sessionID), params, sessionID,
} }
} }
@ -638,7 +694,7 @@ func handleRegistrationFlow(
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeTerms: case authtypes.LoginTypeTerms:
AddCompletedSessionStage(sessionID, authtypes.LoginTypeTerms) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms)
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
@ -647,12 +703,12 @@ func handleRegistrationFlow(
} }
// Add Recaptcha to the list of completed registration stages // Add Recaptcha to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case "": case "":
// An empty auth type means that we want to fetch the available // An empty auth type means that we want to fetch the available
@ -668,7 +724,7 @@ func handleRegistrationFlow(
// Check if the user's registration flow has been completed successfully // Check if the user's registration flow has been completed successfully
// A response with current registration flow and remaining available methods // A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet // will be returned if a flow has not been successfully completed yet
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI) req, r, sessionID, cfg, userAPI)
} }
@ -715,7 +771,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), policyVersion, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
) )
} }
@ -737,13 +793,12 @@ func checkAndCompleteFlow(
policyVersion = cfg.Matrix.UserConsentOptions.Version policyVersion = cfg.Matrix.UserConsentOptions.Version
} }
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), policyVersion, req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
) )
} }
sessions.addParams(sessionID, r)
// There are still more stages to complete. // There are still more stages to complete.
// Return the flows and those that have been completed. // Return the flows and those that have been completed.
return util.JSONResponse{ return util.JSONResponse{
@ -762,11 +817,25 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
username, password, appserviceID, ipAddr, userAgent, policyVersion string, username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
) util.JSONResponse { ) util.JSONResponse {
var registrationOK bool
defer func() {
if registrationOK {
sessions.deleteSession(sessionID)
}
}()
if data, ok := sessions.getParams(sessionID); ok {
username = data.Username
password = data.Password
deviceID = data.DeviceID
displayName = data.InitialDisplayName
inhibitLogin = data.InhibitLogin
}
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -808,6 +877,7 @@ func completeRegistration(
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
registrationOK = true
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{
@ -841,6 +911,7 @@ func completeRegistration(
} }
} }
registrationOK = true
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{
@ -989,5 +1060,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"regexp" "regexp"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
func TestEmptyCompletedFlows(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) {
fakeEmptySessions := newSessionsDict() fakeEmptySessions := newSessionsDict()
fakeSessionID := "aRandomSessionIDWhichDoesNotExist" fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
// check for [] // check for []
if ret == nil || len(ret) != 0 { if ret == nil || len(ret) != 0 {
@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) {
t.Errorf("user_id should not have been valid: @_something_else:localhost") t.Errorf("user_id should not have been valid: @_something_else:localhost")
} }
} }
func TestSessionCleanUp(t *testing.T) {
s := newSessionsDict()
t.Run("session is cleaned up after a while", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld"
// manually added, as s.addParams() would start the timer with the default timeout
s.params[dummySession] = registerRequest{Username: "Testing"}
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld2"
s.startTimer(time.Minute, dummySession)
s.deleteSession(dummySession)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session timer is restarted after second call", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld3"
// the following will start a timer with the default timeout of 5min
s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.getCompletedStages(dummySession)
// reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
}

2
go.mod
View file

@ -39,7 +39,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.10 github.com/mattn/go-sqlite3 v1.14.10

4
go.sum
View file

@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052 h1:+4Q/JQ3fGgA7sIHaLMlqREX8yEpsI+HlVoW9WId7SNc=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/gomatrixserverlib v0.0.0-20220224170509-f6ab9c54d052/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa h1:rMYFNVto66gp+eWS8XAUzgp4m0qmUBid6l1HX3mHstk=
github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/pinecone v0.0.0-20220223104432-0f0afd1a46aa/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -814,6 +814,7 @@ func (v *StateResolution) resolveConflictsV2(
// events may be duplicated across these sets but that's OK. // events may be duplicated across these sets but that's OK.
authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted))
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3)
authDifference := make([]*gomatrixserverlib.Event, 0, estimate) authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
// For each conflicted event, let's try and get the needed auth events. // For each conflicted event, let's try and get the needed auth events.
@ -850,9 +851,22 @@ func (v *StateResolution) resolveConflictsV2(
if err != nil { if err != nil {
return nil, err return nil, err
} }
authEvents = append(authEvents, authSets[key]...)
// Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce
// duplicate entries, which just waste memory and CPU time.
for _, event := range authSets[key] {
if _, ok := gotAuthEvents[event.EventID()]; !ok {
authEvents = append(authEvents, event)
gotAuthEvents[event.EventID()] = struct{}{}
}
}
} }
// Kill the reference to this so that the GC may pick it up, since we no
// longer need this after this point.
gotAuthEvents = nil // nolint:ineffassign
// This function helps us to work out whether an event exists in one of the // This function helps us to work out whether an event exists in one of the
// auth sets. // auth sets.
isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { isInAuthList := func(k string, event *gomatrixserverlib.Event) bool {
@ -866,11 +880,12 @@ func (v *StateResolution) resolveConflictsV2(
// This function works out if an event exists in all of the auth sets. // This function works out if an event exists in all of the auth sets.
isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { isInAllAuthLists := func(event *gomatrixserverlib.Event) bool {
found := true
for k := range authSets { for k := range authSets {
found = found && isInAuthList(k, event) if !isInAuthList(k, event) {
return false
}
} }
return found return true
} }
// Look through all of the auth events that we've been given and work out if // Look through all of the auth events that we've been given and work out if

View file

@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id
# Flakey # Flakey
Local device key changes appear in /keys/changes Local device key changes appear in /keys/changes
/context/ with lazy_load_members filter works
# we don't support groups # we don't support groups
Remove group category Remove group category

View file

@ -601,3 +601,7 @@ Can query remote device keys using POST after notification
Device deletion propagates over federation Device deletion propagates over federation
Get left notifs in sync and /keys/changes when other user leaves Get left notifs in sync and /keys/changes when other user leaves
Remote banned user is kicked and may not rejoin until unbanned Remote banned user is kicked and may not rejoin until unbanned
registration remembers parameters
registration accepts non-ascii passwords
registration with inhibit_login inhibits login