diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 1a4307c18..fe6789fcb 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -30,7 +30,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -39,7 +38,6 @@ func AddPublicRoutes( router *mux.Router, cfg *config.ClientAPI, producer sarama.SyncProducer, - deviceDB devices.Database, accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -59,7 +57,7 @@ func AddPublicRoutes( routing.Setup( router, cfg, eduInputAPI, rsAPI, asAPI, - accountsDB, deviceDB, userAPI, federation, + accountsDB, userAPI, federation, syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider, ) } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index d0b3bdbe5..56886d57f 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -15,7 +15,6 @@ package routing import ( - "database/sql" "encoding/json" "io/ioutil" "net/http" @@ -23,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -50,57 +49,56 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { + var targetDevice *userapi.Device + for _, device := range queryRes.Devices { + if device.ID == deviceID { + targetDevice = &device + break + } + } + if targetDevice == nil { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Unknown device"), } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, JSON: deviceJSON{ - DeviceID: dev.ID, - DisplayName: dev.DisplayName, + DeviceID: targetDevice.ID, + DisplayName: targetDevice.DisplayName, }, } } // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart) - - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } res := devicesJSON{} - for _, dev := range deviceList { + for _, dev := range queryRes.Devices { res.Devices = append(res.Devices, deviceJSON{ DeviceID: dev.ID, DisplayName: dev.DisplayName, diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 3ce47169e..cb300e9ff 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -19,23 +19,21 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) // Logout handles POST /logout func Logout( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{device.ID}, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } @@ -47,16 +45,15 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index c259e5293..f2494dc78 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -35,7 +35,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -52,7 +51,6 @@ func Setup( rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, accountDB accounts.Database, - deviceDB devices.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, @@ -322,13 +320,13 @@ func Setup( r0mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return Logout(req, deviceDB, device) + return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return LogoutAll(req, deviceDB, device) + return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -632,7 +630,7 @@ func Setup( r0mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetDevicesByLocalpart(req, deviceDB, device) + return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -642,7 +640,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetDeviceByID(req, deviceDB, device, vars["deviceID"]) + return GetDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 4961b34e0..35dbb7745 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -27,7 +27,6 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() asQuery := base.AppserviceHTTPClient() @@ -39,7 +38,7 @@ func main() { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, deviceDB, accountDB, federation, + base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation, rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil, ) diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 5e6c8fcfc..a760654a6 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -65,7 +65,7 @@ type Monolith struct { // AddAllPublicRoutes attaches all public paths to the given router func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) { clientapi.AddPublicRoutes( - csMux, &m.Config.ClientAPI, m.KafkaProducer, m.DeviceDB, m.AccountDB, + csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(), m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, diff --git a/userapi/api/api.go b/userapi/api/api.go index 84338dbf2..e6d05c335 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -61,7 +61,7 @@ type PerformDeviceUpdateResponse struct { type PerformDeviceDeletionRequest struct { UserID string - // The devices to delete + // The devices to delete. An empty slice means delete all devices. DeviceIDs []string } @@ -192,8 +192,7 @@ type Device struct { // The unique ID of the session identified by the access token. // Can be used as a secure substitution in places where data needs to be // associated with access tokens. - SessionID int64 - // TODO: display name, last used timestamp, keys, etc + SessionID int64 DisplayName string } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 05cecc1bc..b97f148e0 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -123,12 +123,21 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if domain != a.ServerName { return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) } - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + deletedDeviceIDs := req.DeviceIDs + if len(req.DeviceIDs) == 0 { + var devices []api.Device + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + for _, d := range devices { + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + } + } else { + err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + } if err != nil { return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - return a.deviceListUpdate(req.UserID, req.DeviceIDs) + return a.deviceListUpdate(req.UserID, deletedDeviceIDs) } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 3c9ec934a..9b4261c9d 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -35,5 +35,6 @@ type Database interface { UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error - RemoveAllDevices(ctx context.Context, localpart string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 03bf7c722..282466f8d 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -251,11 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 4a7c7f975..04dae9864 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -176,11 +176,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index c93e8b772..ecf43524a 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -231,11 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 4f426c6ed..50b5721ec 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -180,11 +180,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return }