Remove device DB from clientapi

This commit is contained in:
Kegan Dougal 2020-08-27 16:12:53 +01:00
parent c0f28845f8
commit e1bf62c629
13 changed files with 81 additions and 74 deletions

View file

@ -30,7 +30,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -39,7 +38,6 @@ func AddPublicRoutes(
router *mux.Router, router *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
producer sarama.SyncProducer, producer sarama.SyncProducer,
deviceDB devices.Database,
accountsDB accounts.Database, accountsDB accounts.Database,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
@ -59,7 +57,7 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
router, cfg, eduInputAPI, rsAPI, asAPI, router, cfg, eduInputAPI, rsAPI, asAPI,
accountsDB, deviceDB, userAPI, federation, accountsDB, userAPI, federation,
syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider, syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider,
) )
} }

View file

@ -15,7 +15,6 @@
package routing package routing
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -23,7 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -50,57 +49,56 @@ type devicesDeleteJSON struct {
// GetDeviceByID handles /devices/{deviceID} // GetDeviceByID handles /devices/{deviceID}
func GetDeviceByID( func GetDeviceByID(
req *http.Request, deviceDB devices.Database, device *api.Device, req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) var queryRes userapi.QueryDevicesResponse
err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{
UserID: device.UserID,
}, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
var targetDevice *userapi.Device
ctx := req.Context() for _, device := range queryRes.Devices {
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) if device.ID == deviceID {
if err == sql.ErrNoRows { targetDevice = &device
break
}
}
if targetDevice == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown device"), JSON: jsonerror.NotFound("Unknown device"),
} }
} else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: deviceJSON{ JSON: deviceJSON{
DeviceID: dev.ID, DeviceID: targetDevice.ID,
DisplayName: dev.DisplayName, DisplayName: targetDevice.DisplayName,
}, },
} }
} }
// GetDevicesByLocalpart handles /devices // GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart( func GetDevicesByLocalpart(
req *http.Request, deviceDB devices.Database, device *api.Device, req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) var queryRes userapi.QueryDevicesResponse
err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{
UserID: device.UserID,
}, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
res := devicesJSON{} res := devicesJSON{}
for _, dev := range deviceList { for _, dev := range queryRes.Devices {
res.Devices = append(res.Devices, deviceJSON{ res.Devices = append(res.Devices, deviceJSON{
DeviceID: dev.ID, DeviceID: dev.ID,
DisplayName: dev.DisplayName, DisplayName: dev.DisplayName,

View file

@ -19,23 +19,21 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/devices" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// Logout handles POST /logout // Logout handles POST /logout
func Logout( func Logout(
req *http.Request, deviceDB devices.Database, device *api.Device, req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) var performRes userapi.PerformDeviceDeletionResponse
err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: []string{device.ID},
}, &performRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -47,16 +45,15 @@ func Logout(
// LogoutAll handles POST /logout/all // LogoutAll handles POST /logout/all
func LogoutAll( func LogoutAll(
req *http.Request, deviceDB devices.Database, device *api.Device, req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) var performRes userapi.PerformDeviceDeletionResponse
err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
}, &performRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -35,7 +35,6 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -52,7 +51,6 @@ func Setup(
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
accountDB accounts.Database, accountDB accounts.Database,
deviceDB devices.Database,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
@ -322,13 +320,13 @@ func Setup(
r0mux.Handle("/logout", r0mux.Handle("/logout",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return Logout(req, deviceDB, device) return Logout(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all", r0mux.Handle("/logout/all",
httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return LogoutAll(req, deviceDB, device) return LogoutAll(req, userAPI, device)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
@ -632,7 +630,7 @@ func Setup(
r0mux.Handle("/devices", r0mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, deviceDB, device) return GetDevicesByLocalpart(req, userAPI, device)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -642,7 +640,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetDeviceByID(req, deviceDB, device, vars["deviceID"]) return GetDeviceByID(req, userAPI, device, vars["deviceID"])
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)

View file

@ -27,7 +27,6 @@ func main() {
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
deviceDB := base.CreateDeviceDB()
federation := base.CreateFederationClient() federation := base.CreateFederationClient()
asQuery := base.AppserviceHTTPClient() asQuery := base.AppserviceHTTPClient()
@ -39,7 +38,7 @@ func main() {
keyAPI := base.KeyServerHTTPClient() keyAPI := base.KeyServerHTTPClient()
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, deviceDB, accountDB, federation, base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation,
rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil, rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil,
) )

View file

@ -65,7 +65,7 @@ type Monolith struct {
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) {
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
csMux, &m.Config.ClientAPI, m.KafkaProducer, m.DeviceDB, m.AccountDB, csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(),
m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,

View file

@ -61,7 +61,7 @@ type PerformDeviceUpdateResponse struct {
type PerformDeviceDeletionRequest struct { type PerformDeviceDeletionRequest struct {
UserID string UserID string
// The devices to delete // The devices to delete. An empty slice means delete all devices.
DeviceIDs []string DeviceIDs []string
} }
@ -193,7 +193,6 @@ type Device struct {
// Can be used as a secure substitution in places where data needs to be // Can be used as a secure substitution in places where data needs to be
// associated with access tokens. // associated with access tokens.
SessionID int64 SessionID int64
// TODO: display name, last used timestamp, keys, etc
DisplayName string DisplayName string
} }

View file

@ -123,12 +123,21 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName)
} }
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil { if err != nil {
return err return err
} }
// create empty device keys and upload them to delete what was once there and trigger device list changes // create empty device keys and upload them to delete what was once there and trigger device list changes
return a.deviceListUpdate(req.UserID, req.DeviceIDs) return a.deviceListUpdate(req.UserID, deletedDeviceIDs)
} }
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {

View file

@ -35,5 +35,6 @@ type Database interface {
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
} }

View file

@ -251,11 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -176,11 +176,16 @@ func (d *Database) RemoveDevices(
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) error { ) (devices []api.Device, err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
}) })
return
} }

View file

@ -231,11 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -180,11 +180,16 @@ func (d *Database) RemoveDevices(
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) error { ) (devices []api.Device, err error) {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
}) })
return
} }