diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index 087e45043..da0324251 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -5,6 +5,7 @@ type LoginType string // The relevant login types implemented in Dendrite const ( + LoginTypePassword = "m.login.password" LoginTypeDummy = "m.login.dummy" LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go new file mode 100644 index 000000000..8b81b9f02 --- /dev/null +++ b/clientapi/routing/password.go @@ -0,0 +1,127 @@ +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type newPasswordRequest struct { + NewPassword string `json:"new_password"` + LogoutDevices bool `json:"logout_devices"` + Auth newPasswordAuth `json:"auth"` +} + +type newPasswordAuth struct { + Type string `json:"type"` + Session string `json:"session"` + auth.PasswordRequest +} + +func Password( + req *http.Request, + userAPI userapi.UserInternalAPI, + accountDB accounts.Database, + device *api.Device, + cfg *config.ClientAPI, +) util.JSONResponse { + // Check that the existing password is right. + var r newPasswordRequest + r.LogoutDevices = true + + // Unmarshal the request. + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // Retrieve or generate the sessionID + sessionID := r.Auth.Session + if sessionID == "" { + // Generate a new, random session ID + sessionID = util.RandomString(sessionIDLength) + } + + // Require password auth to change the password. + if r.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + ), + } + } + + // Check if the existing password is correct. + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { + return *authErr + } + AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + + // Check the new password strength. + if resErr = validatePassword(r.NewPassword); resErr != nil { + return *resErr + } + + // Get the local part. + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + + // Ask the user API to perform the password change. + passwordReq := &userapi.PerformPasswordUpdateRequest{ + Localpart: localpart, + Password: r.NewPassword, + } + passwordRes := &userapi.PerformPasswordUpdateResponse{} + if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") + return jsonerror.InternalServerError() + } + if !passwordRes.PasswordUpdated { + util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") + return jsonerror.InternalServerError() + } + + // If the request asks us to log out all other devices then + // ask the user API to do that. + if r.LogoutDevices { + logoutReq := &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + ExceptDeviceID: device.ID, + } + logoutRes := &userapi.PerformDeviceDeletionResponse{} + if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") + return jsonerror.InternalServerError() + } + } + + // Return a success code. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 708f6feeb..b29fccf2e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -417,6 +417,15 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/account/password", + httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + return Password(req, userAPI, accountDB, device, cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub endpoints required by Riot r0mux.Handle("/login", diff --git a/sytest-whitelist b/sytest-whitelist index 7ce59fef6..93d2de593 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy Can search public room list Can get remote public room list Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list +After changing password, can't log in with old password +After changing password, can log in with new password +After changing password, existing session still works +After changing password, different sessions can optionally be kept +After changing password, a different session no longer works by default diff --git a/userapi/api/api.go b/userapi/api/api.go index e6d05c335..3baaa1002 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -26,6 +26,7 @@ import ( type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct { UserID string // The devices to delete. An empty slice means delete all devices. DeviceIDs []string + // The requesting device ID to exclude from deletion. This is needed + // so that a password change doesn't cause that client to be logged + // out. Only specify when DeviceIDs is empty. + ExceptDeviceID string } type PerformDeviceDeletionResponse struct { @@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct { Account *Account } +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformPasswordUpdateRequest struct { + Localpart string // Required: The localpart for this account. + Password string // Required: The new password to set. +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformPasswordUpdateResponse struct { + PasswordUpdated bool + Account *Account +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b97f148e0..461c548cc 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = acc return nil } + +func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + return err + } + res.PasswordUpdated = true + return nil +} + func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, @@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 5f4df0eb1..6dcaf7568 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" @@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformPasswordUpdate( + ctx context.Context, + request *api.PerformPasswordUpdateRequest, + response *api.PerformPasswordUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformPasswordUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) PerformDeviceCreation( ctx context.Context, request *api.PerformDeviceCreationRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 47d68ff21..d26746788 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformPasswordUpdateRequest{} + response := api.PerformPasswordUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceCreationPath, httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceCreationRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 86b91e603..49446f11f 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -28,6 +28,7 @@ type Database interface { internal.PartitionStorer GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 931ffb73d..8c8d32cf8 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT nextval('numeric_username_seq')" -// TODO: Update password - type accountsStatements struct { insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b36264dd9..8b9ebef80 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -112,6 +112,17 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.accounts.updatePassword(ctx, localpart, hash) +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 798a6de96..fbbdc3370 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COUNT(localpart) FROM account_accounts" -// TODO: Update password - type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 46106297b..4b66304c2 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -126,6 +126,18 @@ func (d *Database) SetDisplayName( }) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + err = d.accounts.updatePassword(ctx, localpart, hash) + return err +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9b4261c9d..168c84c5c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -36,5 +36,5 @@ type Database interface { RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 282466f8d..c06af7549 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -79,7 +79,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" @@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 04dae9864..c5bd5b6cf 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -175,14 +175,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ecf43524a..c75e19825 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -68,7 +68,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" @@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices( } func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f775fb664..7c6645dd6 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -179,14 +179,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil