From 67149833492ad74d9a2759b92f918f96e678dc5a Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Sun, 26 Sep 2021 14:35:14 +0200 Subject: [PATCH] Refactor LoginTypePassword and Type to support m.login.token and m.login.sso. For login token: * m.login.token will require deleting the token after completeAuth has generated an access token, so a cleanup function is returned by Type.Login. * Allowing different login types will require parsing the /login body twice: first to extract the "type" and then the type-specific parsing. Thus, we will have to buffer the request JSON in /login, like UserInteractive already does. For SSO: * NewUserInteractive will have to also use GetAccountByLocalpart. It makes more sense to just pass a (narrowed-down) accountDB interface to it than adding more function pointers. Code quality: * Passing around (and down-casting) interface{} for login request types has drawbacks in terms of type-safety, and no inherent benefits. We always decode JSON anyway. Hence renaming to Type.LoginFromJSON. Code that directly uses LoginTypePassword with parsed data can still use Login. * Removed a TODO for SSO. This is already tracked in #1297. * httputil.UnmarshalJSON is useful because it returns a JSONResponse. This change is intended to have no functional changes. --- clientapi/auth/auth.go | 1 + clientapi/auth/password.go | 22 ++++++++++----- clientapi/auth/user_interactive.go | 36 +++++++++++-------------- clientapi/auth/user_interactive_test.go | 8 ++++-- clientapi/httputil/httputil.go | 4 +++ clientapi/routing/login.go | 18 ++++++++----- clientapi/routing/routing.go | 2 +- 7 files changed, 55 insertions(+), 36 deletions(-) diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index b4c39ae38..481bd36b2 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -42,6 +42,7 @@ type DeviceDatabase interface { type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 7dd21b3f2..19f7e3773 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -19,6 +19,8 @@ import ( "net/http" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -40,16 +42,24 @@ type LoginTypePassword struct { } func (t *LoginTypePassword) Name() string { - return "m.login.password" + return authtypes.LoginTypePassword } -func (t *LoginTypePassword) Request() interface{} { - return &PasswordRequest{} +func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r PasswordRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + login, err := t.Login(ctx, &r) + if err != nil { + return nil, nil, err + } + + return login, func(context.Context, *util.JSONResponse) {}, nil } -func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { - r := req.(*PasswordRequest) - // Squash username to all lowercase letters +func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Login, *util.JSONResponse) { username := strings.ToLower(r.Username()) if username == "" { return nil, &util.JSONResponse{ diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 30469fc47..447690ef2 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -32,22 +32,24 @@ import ( type Type interface { // Name returns the name of the auth type e.g `m.login.password` Name() string - // Request returns a pointer to a new request body struct to unmarshal into. - Request() interface{} // Login with the auth type, returning an error response on failure. // Not all types support login, only m.login.password and m.login.token // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login - // `req` is guaranteed to be the type returned from Request() // This function will be called when doing login and when doing 'sudo' style // actions e.g deleting devices. The response must be a 401 as per: // "If the homeserver decides that an attempt on a stage was unsuccessful, but the // client may make a second attempt, it returns the same HTTP status 401 response as above, // with the addition of the standard errcode and error fields describing the error." - Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) + // + // The returned cleanup function must be non-nil on success, and will be called after + // authorization has been completed. Its argument is the final result of authorization. + LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse) // TODO: Extend to support Register() flow // Register(ctx context.Context, sessionID string, req interface{}) } +type LoginCleanupFunc func(context.Context, *util.JSONResponse) + // LoginIdentifier represents identifier types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types type LoginIdentifier struct { @@ -111,12 +113,11 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: getAccByPass, + GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - // TODO: Add SSO login return &UserInteractive{ Completed: []string{}, Flows: []userInteractiveFlow{ @@ -236,18 +237,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * } } - r := loginType.Request() - if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), - } + login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw)) + if resErr != nil { + return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) } - login, resErr := loginType.Login(ctx, r) - if resErr == nil { - u.AddCompletedStage(sessionID, authType) - // TODO: Check if there's more stages to go and return an error - return login, nil - } - return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) + + u.AddCompletedStage(sessionID, authType) + cleanup(ctx, nil) + // TODO: Check if there's more stages to go and return an error + return login, nil } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 0b7df3545..76d161a74 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,11 @@ var ( } ) -func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { +type fakeAccountDatabase struct { + AccountDatabase +} + +func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { acc, ok := lookup[localpart+" "+plaintextPassword] if !ok { return nil, fmt.Errorf("unknown user/password") @@ -38,7 +42,7 @@ func setup() *UserInteractive { ServerName: serverName, }, } - return NewUserInteractive(getAccountByPassword, cfg) + return NewUserInteractive(&fakeAccountDatabase{}, cfg) } func TestUserInteractiveChallenge(t *testing.T) { diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 29d7b0b37..b47701368 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon return &resp } + return UnmarshalJSON(body, iface) +} + +func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 589efe0b2..3a3c06319 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -16,10 +16,10 @@ package routing import ( "context" + "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -69,17 +69,21 @@ func Login( GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - r := typePassword.Request() - resErr := httputil.UnmarshalJSONRequest(req, r) - if resErr != nil { - return *resErr + body, err := ioutil.ReadAll(req.Body) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } } - login, authErr := typePassword.Login(req.Context(), r) + login, cleanup, authErr := typePassword.LoginFromJSON(req.Context(), body) if authErr != nil { return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authzErr := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authzErr) + return authzErr } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9263c66bb..732066166 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -62,7 +62,7 @@ func Setup( mscCfg *config.MSCs, ) { rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) - userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) + userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true,