From af94d74362b4dfb802c3a6dbce8ac44bf8d39ec9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Sep 2020 12:36:20 +0100 Subject: [PATCH] User API support for password changes --- userapi/api/api.go | 17 +++++++++++++++++ userapi/internal/api.go | 13 ++++++++++++- userapi/inthttp/client.go | 13 +++++++++++++ userapi/inthttp/server.go | 13 +++++++++++++ userapi/storage/accounts/interface.go | 1 + .../storage/accounts/postgres/accounts_table.go | 16 ++++++++++++++-- userapi/storage/accounts/postgres/storage.go | 11 +++++++++++ .../storage/accounts/sqlite3/accounts_table.go | 16 ++++++++++++++-- userapi/storage/accounts/sqlite3/storage.go | 15 +++++++++++++++ 9 files changed, 110 insertions(+), 5 deletions(-) diff --git a/userapi/api/api.go b/userapi/api/api.go index e6d05c335..3baaa1002 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -26,6 +26,7 @@ import ( type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct { UserID string // The devices to delete. An empty slice means delete all devices. DeviceIDs []string + // The requesting device ID to exclude from deletion. This is needed + // so that a password change doesn't cause that client to be logged + // out. Only specify when DeviceIDs is empty. + ExceptDeviceID string } type PerformDeviceDeletionResponse struct { @@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct { Account *Account } +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformPasswordUpdateRequest struct { + Localpart string // Required: The localpart for this account. + Password string // Required: The new password to set. +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformPasswordUpdateResponse struct { + PasswordUpdated bool + Account *Account +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b97f148e0..8c7b610ed 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = acc return nil } + +func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + return err + } + res.PasswordUpdated = true + return nil +} + func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, @@ -128,7 +137,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe var devices []api.Device devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) for _, d := range devices { - deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + if d.ID != req.ExceptDeviceID { + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + } } } else { err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 5f4df0eb1..6dcaf7568 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" @@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformPasswordUpdate( + ctx context.Context, + request *api.PerformPasswordUpdateRequest, + response *api.PerformPasswordUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformPasswordUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) PerformDeviceCreation( ctx context.Context, request *api.PerformDeviceCreationRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 47d68ff21..d26746788 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformPasswordUpdateRequest{} + response := api.PerformPasswordUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceCreationPath, httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceCreationRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 86b91e603..49446f11f 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -28,6 +28,7 @@ type Database interface { internal.PartitionStorer GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 931ffb73d..8c8d32cf8 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT nextval('numeric_username_seq')" -// TODO: Update password - type accountsStatements struct { insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b36264dd9..8b9ebef80 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -112,6 +112,17 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.accounts.updatePassword(ctx, localpart, hash) +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 798a6de96..fbbdc3370 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COUNT(localpart) FROM account_accounts" -// TODO: Update password - type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 46106297b..eaed0d493 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "sync" @@ -126,6 +127,20 @@ func (d *Database) SetDisplayName( }) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + fmt.Println("PASSWORD:", localpart, plaintextPassword, hash) + err = d.accounts.updatePassword(ctx, localpart, hash) + fmt.Println("ERROR:", err) + return err +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {