diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 22a45ec1b..e295f8f07 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -16,6 +16,8 @@ package auth import ( "context" + "database/sql" + "net/http" "reflect" "strings" "testing" @@ -23,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" ) func TestLoginFromJSONReader(t *testing.T) { @@ -32,12 +35,10 @@ func TestLoginFromJSONReader(t *testing.T) { Name string Body string - WantErrCode string WantUsername string WantDeviceID string WantDeletedTokens []string }{ - {Name: "empty", WantErrCode: "M_BAD_JSON"}, { Name: "passwordWorks", Body: `{ @@ -70,20 +71,11 @@ func TestLoginFromJSONReader(t *testing.T) { ServerName: serverName, }, } - login, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) - if tst.WantErrCode == "" { - if errRes != nil { - t.Fatalf("LoginFromJSONReader failed: %+v", errRes) - } - cleanup(ctx, nil) - } else { - if errRes == nil { - t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) - } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { - t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) - } - return + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if err != nil { + t.Fatalf("LoginFromJSONReader failed: %+v", err) } + cleanup(ctx, &util.JSONResponse{Code: http.StatusOK}) if login.Username() != tst.WantUsername { t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername) @@ -106,11 +98,78 @@ func TestLoginFromJSONReader(t *testing.T) { } } +func TestBadLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantErrCode string + }{ + {Name: "empty", WantErrCode: "M_BAD_JSON"}, + { + Name: "badUnmarshal", + Body: `badsyntaxJSON`, + WantErrCode: "M_BAD_JSON", + }, + { + Name: "badPassword", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "invalidpassword", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badToken", + Body: `{ + "type": "m.login.token", + "token": "invalidtoken", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badType", + Body: `{ + "type": "m.login.invalid", + "device_id": "adevice" + }`, + WantErrCode: "M_INVALID_ARGUMENT_VALUE", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if errRes == nil { + cleanup(ctx, nil) + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } + }) + } +} + type fakeAccountDB struct { AccountDatabase } func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { + if password == "invalidpassword" { + return nil, sql.ErrNoRows + } + return &uapi.Account{}, nil } @@ -126,6 +185,10 @@ func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, re } func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { + if req.Token == "invalidtoken" { + return nil + } + res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"} return nil } diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index bfc745006..845eb5de9 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -45,17 +45,6 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L return nil, nil, err } - return t.login(ctx, &r) -} - -// loginTokenRequest struct to hold the possible parameters from an HTTP request. -type loginTokenRequest struct { - Login - Token string `json:"token"` -} - -// login parses and validates the login token. It returns basic user information. -func (t *LoginTypeToken) login(ctx context.Context, r *loginTokenRequest) (*Login, LoginCleanupFunc, *util.JSONResponse) { var res uapi.QueryLoginTokenResponse if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") @@ -73,7 +62,11 @@ func (t *LoginTypeToken) login(ctx context.Context, r *loginTokenRequest) (*Logi r.Login.Identifier.User = res.Data.UserID cleanup := func(ctx context.Context, authRes *util.JSONResponse) { - if authRes == nil || authRes.Code == http.StatusOK { + if authRes == nil { + util.GetLogger(ctx).Error("No JSONResponse provided to LoginTokenType cleanup function") + return + } + if authRes.Code == http.StatusOK { var res uapi.PerformLoginTokenDeletionResponse if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed") @@ -82,3 +75,9 @@ func (t *LoginTypeToken) login(ctx context.Context, r *loginTokenRequest) (*Logi } return &r.Login, cleanup, nil } + +// loginTokenRequest struct to hold the possible parameters from an HTTP request. +type loginTokenRequest struct { + Login + Token string `json:"token"` +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index e3effbe99..b48b9e93b 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" - uapi "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -54,7 +54,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI uapi.UserInternalAPI, + req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -69,9 +69,9 @@ func Login( return *authErr } // make a device/access token - authzErr := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) - cleanup(req.Context(), &authzErr) - return authzErr + authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authErr2) + return authErr2 } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, @@ -80,7 +80,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI uapi.UserInternalAPI, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() @@ -95,8 +95,8 @@ func completeAuth( return jsonerror.InternalServerError() } - var performRes uapi.PerformDeviceCreationResponse - err = userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + var performRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ DeviceDisplayName: login.InitialDisplayName, DeviceID: login.DeviceID, AccessToken: token, diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go index f5b7e3974..f3aa037e4 100644 --- a/userapi/api/api_logintoken.go +++ b/userapi/api/api_logintoken.go @@ -23,7 +23,8 @@ type LoginTokenInternalAPI interface { // PerformLoginTokenCreation creates a new login token and associates it with the provided data. PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error - // PerformLoginTokenDeletion ensures the token doesn't exist. + // PerformLoginTokenDeletion ensures the token doesn't exist. Success + // is returned even if the token didn't exist, or was already deleted. PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error // QueryLoginToken returns the data associated with a login token. If diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index b79ff054c..86ffc58f3 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -51,7 +51,7 @@ func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *ap // QueryLoginToken returns the data associated with a login token. If // the token is not valid, success is returned, but res.Data == nil. func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { - tokenData, err := a.DeviceDB.GetLoginTokenByToken(ctx, req.Token) + tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) if err != nil { res.Data = nil if err == sql.ErrNoRows { @@ -59,10 +59,13 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog } return err } - localpart, _, err := gomatrixserverlib.SplitID('@', tokenData.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', tokenData.UserID) if err != nil { return err } + if domain != a.ServerName { + return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + } if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil if err == sql.ErrNoRows { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index cb7602f0a..8ff91cf1c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -46,7 +46,7 @@ type Database interface { // RemoveLoginToken removes the named token (and may clean up other expired tokens). RemoveLoginToken(ctx context.Context, token string) error - // GetLoginTokenByToken returns the data associated with the given token. + // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. - GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/devices/postgres/logintoken_table.go index b695b5b99..f601fc7db 100644 --- a/userapi/storage/devices/postgres/logintoken_table.go +++ b/userapi/storage/devices/postgres/logintoken_table.go @@ -25,9 +25,9 @@ import ( ) type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectStmt *sql.Stmt + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt } // execSchema ensures tables and indices exist. @@ -54,7 +54,7 @@ func (s *loginTokenStatements) prepare(db *sql.DB) error { return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, }.Prepare(db) } @@ -76,7 +76,7 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t return err } if n, err := res.RowsAffected(); err == nil && n > 1 { - util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) } return nil } @@ -84,7 +84,7 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 881788066..fd9d513f1 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -263,8 +263,8 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { }) } -// GetLoginTokenByToken returns the data associated with the given token. +// GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. -func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.loginTokens.selectByToken(ctx, token) } diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/devices/sqlite3/logintoken_table.go index 96f4ee416..75ef272f8 100644 --- a/userapi/storage/devices/sqlite3/logintoken_table.go +++ b/userapi/storage/devices/sqlite3/logintoken_table.go @@ -25,9 +25,9 @@ import ( ) type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectStmt *sql.Stmt + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt } // execSchema ensures tables and indices exist. @@ -54,7 +54,7 @@ func (s *loginTokenStatements) prepare(db *sql.DB) error { return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, }.Prepare(db) } @@ -76,7 +76,7 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t return err } if n, err := res.RowsAffected(); err == nil && n > 1 { - util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) } return nil } @@ -84,7 +84,7 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 7f76166b9..6e90413be 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -264,8 +264,8 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { }) } -// GetLoginTokenByToken returns the data associated with the given token. +// GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. -func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.loginTokens.selectByToken(ctx, token) }