diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index 98ea279a1..ddb72b71c 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -1762,6 +1762,13 @@ func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID { func TestKeyBackup(t *testing.T) { alice := test.NewUser(t) + handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) { + t.Helper() + if rec.Code != expectedCode { + t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String()) + } + } + testCases := []struct { name string request func(t *testing.T) *http.Request @@ -1770,40 +1777,299 @@ func TestKeyBackup(t *testing.T) { { name: "can not create backup with invalid JSON", request: func(t *testing.T) *http.Request { - reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) + reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) }, validate: func(t *testing.T, rec *httptest.ResponseRecorder) { - if rec.Code != http.StatusBadRequest { - t.Fatalf("HTTP[%d]: expected an error, but got none: %s", rec.Code, rec.Body.String()) - } + handleResponseCode(t, rec, http.StatusBadRequest) }, }, { - name: "can create backup", + name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again request: func(t *testing.T) *http.Request { reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`) return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) }, validate: func(t *testing.T, rec *httptest.ResponseRecorder) { - if rec.Code != http.StatusOK { - t.Fatalf("HTTP[%d]: expected no error, but got: %s", rec.Code, rec.Body.String()) - } + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can create backup", + request: func(t *testing.T) *http.Request { + reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`) + return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) wantVersion := "1" if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion { t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion) } }, }, + { + name: "can not query backup for invalid version", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can not query backup for invalid version string", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can query backup", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + wantVersion := "1" + if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion { + t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion) + } + }, + }, + { + name: "can query backup without returning rooms", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 { + t.Fatalf("expected no rooms in version, but got %#v", gotRooms) + } + }, + }, + { + name: "can query backup for invalid room", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 { + t.Fatalf("expected no sessions in version, but got %#v", gotSessions) + } + }, + }, + { + name: "can not query backup for invalid session", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can not update backup with missing version", + request: func(t *testing.T) *http.Request { + return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys") + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can not update backup with invalid data", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, "") + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "0", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can not update backup with wrong version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "rooms": map[string]interface{}{ + "!testroom:test": map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{}, + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "5", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusForbidden) + }, + }, + { + name: "can update backup with correct version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "rooms": map[string]interface{}{ + "!testroom:test": map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{ + "dummySession": { + FirstMessageIndex: 1, + }, + }, + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can update backup with correct version for specific room", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{ + "dummySession": { + FirstMessageIndex: 1, + IsVerified: true, + SessionData: json.RawMessage("{}"), + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + t.Logf("%#v", rec.Body.String()) + }, + }, + { + name: "can update backup with correct version for specific room and session", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can update backup by version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{asd}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"})) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + t.Logf("%#v", rec.Body.String()) + }, + }, + { + name: "can not update backup by version for invalid version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"})) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can query backup sessions", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 { + t.Fatalf("expected one room in response, but got %#v", rec.Body.String()) + } + }, + }, + { + name: "can query backup sessions by room", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 { + t.Fatalf("expected one session in response, but got %#v", rec.Body.String()) + } + }, + }, + { + name: "can query backup sessions by room and sessionID", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() { + t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String()) + } + }, + }, { name: "can not delete invalid version backup", request: func(t *testing.T) *http.Request { return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil) }, validate: func(t *testing.T, rec *httptest.ResponseRecorder) { - if rec.Code != http.StatusNotFound { - t.Fatalf("HTTP[%d]: expected HTTP 404, but got: %d", rec.Code, rec.Code) - } + handleResponseCode(t, rec, http.StatusNotFound) }, }, { @@ -1812,9 +2078,7 @@ func TestKeyBackup(t *testing.T) { return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil) }, validate: func(t *testing.T, rec *httptest.ResponseRecorder) { - if rec.Code != http.StatusOK { - t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code) - } + handleResponseCode(t, rec, http.StatusOK) }, }, { @@ -1823,9 +2087,7 @@ func TestKeyBackup(t *testing.T) { return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil) }, validate: func(t *testing.T, rec *httptest.ResponseRecorder) { - if rec.Code != http.StatusOK { - t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code) - } + handleResponseCode(t, rec, http.StatusOK) }, }, } diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index be7d13a96..436e168ab 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -171,6 +171,23 @@ func LeaveServerNoticeError() *MatrixError { } } +// ErrRoomKeysVersion is an error returned by `PUT /room_keys/keys` +type ErrRoomKeysVersion struct { + MatrixError + CurrentVersion string `json:"current_version"` +} + +// WrongBackupVersionError is an error returned by `PUT /room_keys/keys` +func WrongBackupVersionError(currentVersion string) *ErrRoomKeysVersion { + return &ErrRoomKeysVersion{ + MatrixError: MatrixError{ + ErrCode: "M_WRONG_ROOM_KEYS_VERSION", + Err: "Wrong backup version.", + }, + CurrentVersion: currentVersion, + } +} + type IncompatibleRoomVersionError struct { RoomVersion string `json:"room_version"` Error string `json:"error"` diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 0f65f75a1..56b05db15 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -61,22 +61,26 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de if resErr != nil { return *resErr } - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if len(kb.AuthData) == 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("missing auth_data"), + } + } + version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: "", AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp); err != nil { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", err)) } + return util.JSONResponse{ Code: 200, JSON: keyBackupVersionCreateResponse{ - Version: performKeyBackupResp.Version, + Version: version, }, } } @@ -84,15 +88,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de // KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { - var queryResp userapi.QueryKeyBackupResponse - if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, - }, &queryResp); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if queryResp.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", err)) } if !queryResp.Exists { return util.JSONResponse{ @@ -120,31 +121,29 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse if resErr != nil { return *resErr } - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp); err != nil { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } + }) + switch e := err.(type) { + case *jsonerror.ErrRoomKeysVersion: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: e, } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + case nil: + default: + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e)) } + if !performKeyBackupResp.Exists { return util.JSONResponse{ Code: 404, JSON: jsonerror.NotFound("backup version not found"), } } - // Unclear what the 200 body should be return util.JSONResponse{ Code: 200, JSON: keyBackupVersionCreateResponse{ @@ -176,22 +175,21 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de func UploadBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, ) util.JSONResponse { - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, Keys: *keys, - }, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } + }) + + switch e := err.(type) { + case *jsonerror.ErrRoomKeysVersion: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: e, } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + case nil: + default: + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e)) } if !performKeyBackupResp.Exists { return util.JSONResponse{ @@ -212,18 +210,15 @@ func UploadBackupKeys( func GetBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, ) util.JSONResponse { - var queryResp userapi.QueryKeyBackupResponse - if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, ReturnKeys: true, KeysForRoomID: roomID, KeysForSessionID: sessionID, - }, &queryResp); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if queryResp.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %w", err)) } if !queryResp.Exists { return util.JSONResponse{ @@ -245,17 +240,20 @@ func GetBackupKeys( } } else if roomID != "" { roomData, ok := queryResp.Keys[roomID] - if ok { - // wrap response in "sessions" - return util.JSONResponse{ - Code: 200, - JSON: struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }{ - Sessions: roomData, - }, - } + if !ok { + // If no keys are found, then an object with an empty sessions property will be returned + roomData = make(map[string]userapi.KeyBackupSession) } + // wrap response in "sessions" + return util.JSONResponse{ + Code: 200, + JSON: struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + }, + } + } else { // response is the same as the upload request var resp keyBackupSessionRequest diff --git a/userapi/api/api.go b/userapi/api/api.go index fef6c2af3..4e13a3b94 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -115,8 +115,9 @@ type ClientUserAPI interface { type KeyBackupAPI interface { DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) - PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error - QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest) (string, error) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest) (*QueryKeyBackupResponse, error) + UpdateBackupKeyAuthData(ctx context.Context, req *PerformKeyBackupRequest) (*PerformKeyBackupResponse, error) } type ProfileAPI interface { @@ -184,9 +185,6 @@ type InternalKeyBackupSession struct { } type PerformKeyBackupResponse struct { - Error string // set if there was a problem performing the request - BadInput bool // if set, the Error was due to bad input (HTTP 400) - Exists bool // set to true if the Version exists Version string // the newly created version @@ -204,7 +202,6 @@ type QueryKeyBackupRequest struct { } type QueryKeyBackupResponse struct { - Error string Exists bool Algorithm string `json:"algorithm"` diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 9745b7a73..e9bf62b32 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -25,6 +25,7 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/gomatrixserverlib" @@ -687,47 +688,39 @@ func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version s return a.DB.DeleteKeyBackup(ctx, userID, version) } -func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest) (string, error) { // Create metadata - if req.Version == "" { - version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to create backup: %s", err) - } - res.Exists = err == nil - res.Version = version - return nil - } + return a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) +} + +func (a *UserInternalAPI) UpdateBackupKeyAuthData(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) { + res := &api.PerformKeyBackupResponse{} // Update metadata if len(req.Keys.Rooms) == 0 { err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to update backup: %s", err) - } res.Exists = err == nil res.Version = req.Version - return nil + if err != nil { + return res, fmt.Errorf("failed to update backup: %w", err) + } + return res, nil } // Upload Keys for a specific version metadata - a.uploadBackupKeys(ctx, req, res) - return nil + return a.uploadBackupKeys(ctx, req) } -func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) { + res := &api.PerformKeyBackupResponse{} // you can only upload keys for the CURRENT version version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { - res.Error = fmt.Sprintf("failed to query version: %s", err) - return + return res, fmt.Errorf("failed to query version: %w", err) } if deleted { - res.Error = "backup was deleted" - return + return res, fmt.Errorf("backup was deleted") } if version != req.Version { - res.BadInput = true - res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) - return + return res, jsonerror.WrongBackupVersionError(version) } res.Exists = true res.Version = version @@ -745,23 +738,25 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { - res.Error = fmt.Sprintf("failed to upsert keys: %s", err) - return + return res, fmt.Errorf("failed to upsert keys: %w", err) } res.KeyCount = count res.KeyETag = etag + return res, nil } -func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest) (*api.QueryKeyBackupResponse, error) { + res := &api.QueryKeyBackupResponse{} version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { - if err == sql.ErrNoRows { - res.Exists = false - return nil + if errors.Is(err, sql.ErrNoRows) { + return res, nil } - res.Error = fmt.Sprintf("failed to query key backup: %s", err) - return nil + if errors.Is(err, strconv.ErrSyntax) { + return res, nil + } + return res, fmt.Errorf("failed to query key backup: %s", err) } res.Algorithm = algorithm res.AuthData = authData @@ -771,18 +766,17 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB if !req.ReturnKeys { res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { - res.Error = fmt.Sprintf("failed to count keys: %s", err) + return res, fmt.Errorf("failed to count keys: %w", err) } - return nil + return res, nil } result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { - res.Error = fmt.Sprintf("failed to query keys: %s", err) - return nil + return res, fmt.Errorf("failed to query keys: %s", err) } res.Keys = result - return nil + return res, nil } func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {