From 505dea2a00ac412625b0e26775b850aa6b2ade9d Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 31 Jul 2020 11:25:59 +0100 Subject: [PATCH] Add PerformDeviceUpdate and fix a few bugs - Correct device deletion query on sqlite - Return no keys on /keys/query rather than an empty key --- clientapi/routing/device.go | 48 ++++++++----------- clientapi/routing/routing.go | 2 +- keyserver/internal/internal.go | 3 ++ userapi/api/api.go | 11 +++++ userapi/internal/api.go | 29 +++++++++++ userapi/inthttp/client.go | 9 ++++ userapi/inthttp/server.go | 13 +++++ .../storage/devices/sqlite3/devices_table.go | 2 +- 8 files changed, 88 insertions(+), 29 deletions(-) diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 11c6c7827..d0b3bdbe5 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -115,33 +115,9 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), - } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() - } - - if dev.UserID != device.UserID { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), - } - } defer req.Body.Close() // nolint: errcheck @@ -152,10 +128,28 @@ func UpdateDeviceByID( return jsonerror.InternalServerError() } - if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed") + var performRes api.PerformDeviceUpdateResponse + err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{ + RequestingUserID: device.UserID, + DeviceID: deviceID, + DisplayName: payload.DisplayName, + }, &performRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") return jsonerror.InternalServerError() } + if !performRes.DeviceExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("device does not exist"), + } + } + if performRes.Forbidden { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2d992392f..0e58129ef 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -644,7 +644,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) + return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodPut, http.MethodOptions) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 480d1084e..bb8286635 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -206,6 +206,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } // inject display name if known dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` diff --git a/userapi/api/api.go b/userapi/api/api.go index 5c964c4fd..84338dbf2 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -28,6 +28,7 @@ type UserInternalAPI interface { PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -48,6 +49,16 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +type PerformDeviceUpdateRequest struct { + RequestingUserID string + DeviceID string + DisplayName *string +} +type PerformDeviceUpdateResponse struct { + DeviceExists bool + Forbidden bool +} + type PerformDeviceDeletionRequest struct { UserID string // The devices to delete diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5de308166..b9d188229 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -148,6 +148,35 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er return nil } +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } + return nil +} + func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 47e2110f9..5f4df0eb1 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -31,6 +31,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -104,6 +105,14 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ebb9bf4e8..47d68ff21 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeviceUpdatePath, + httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceUpdateRequest{} + response := api.PerformDeviceUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceDeletionPath, httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceDeletionRequest{} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index efe6f927c..e386b9b86 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadic(len(devices)), 1) prep, err := s.db.Prepare(orig) if err != nil { return err