Add PerformDeviceUpdate and fix a few bugs

- Correct device deletion query on sqlite
- Return no keys on /keys/query rather than an empty key
This commit is contained in:
Kegan Dougal 2020-07-31 11:25:59 +01:00
parent d23d031565
commit 505dea2a00
8 changed files with 88 additions and 29 deletions

View file

@ -115,33 +115,9 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID} // UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID( func UpdateDeviceByID(
req *http.Request, deviceDB devices.Database, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
deviceID string, deviceID string,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
if err == sql.ErrNoRows {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown device"),
}
} else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
}
if dev.UserID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("device not owned by current user"),
}
}
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
@ -152,10 +128,28 @@ func UpdateDeviceByID(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { var performRes api.PerformDeviceUpdateResponse
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed") err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{
RequestingUserID: device.UserID,
DeviceID: deviceID,
DisplayName: payload.DisplayName,
}, &performRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if !performRes.DeviceExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("device does not exist"),
}
}
if performRes.Forbidden {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("device not owned by current user"),
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -644,7 +644,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)

View file

@ -206,6 +206,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
for _, dk := range deviceKeys { for _, dk := range deviceKeys {
if len(dk.KeyJSON) == 0 {
continue // don't include blank keys
}
// inject display name if known // inject display name if known
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`

View file

@ -28,6 +28,7 @@ type UserInternalAPI interface {
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
@ -48,6 +49,16 @@ type InputAccountDataRequest struct {
type InputAccountDataResponse struct { type InputAccountDataResponse struct {
} }
type PerformDeviceUpdateRequest struct {
RequestingUserID string
DeviceID string
DisplayName *string
}
type PerformDeviceUpdateResponse struct {
DeviceExists bool
Forbidden bool
}
type PerformDeviceDeletionRequest struct { type PerformDeviceDeletionRequest struct {
UserID string UserID string
// The devices to delete // The devices to delete

View file

@ -148,6 +148,35 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
return nil return nil
} }
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
} else if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
return err
}
res.DeviceExists = true
if dev.UserID != req.RequestingUserID {
res.Forbidden = true
return nil
}
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
}
return nil
}
func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {

View file

@ -31,6 +31,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
@ -104,6 +105,14 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryProfile( func (h *httpUserInternalAPI) QueryProfile(
ctx context.Context, ctx context.Context,
request *api.QueryProfileRequest, request *api.QueryProfileRequest,

View file

@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformDeviceUpdatePath,
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceUpdateRequest{}
response := api.PerformDeviceUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceDeletionPath, internalAPIMux.Handle(PerformDeviceDeletionPath,
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceDeletionRequest{} request := api.PerformDeviceDeletionRequest{}

View file

@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig) prep, err := s.db.Prepare(orig)
if err != nil { if err != nil {
return err return err