Add key backup tests (#3071)
Also slightly refactors the functions and methods to rely less on the req/res pattern we had for polylith. Returns `M_WRONG_ROOM_KEYS_VERSION` for some endpoints as per the spec
This commit is contained in:
parent
6b47cf0f6a
commit
9e9617ff84
|
@ -1758,3 +1758,377 @@ func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventC
|
|||
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||
return []id.RoomID{}
|
||||
}
|
||||
|
||||
func TestKeyBackup(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
|
||||
handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) {
|
||||
t.Helper()
|
||||
if rec.Code != expectedCode {
|
||||
t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request func(t *testing.T) *http.Request
|
||||
validate func(t *testing.T, rec *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "can not create backup with invalid JSON",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces
|
||||
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
|
||||
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can create backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`)
|
||||
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
wantVersion := "1"
|
||||
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not query backup for invalid version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not query backup for invalid version string",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
wantVersion := "1"
|
||||
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup without returning rooms",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 {
|
||||
t.Fatalf("expected no rooms in version, but got %#v", gotRooms)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup for invalid room",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 {
|
||||
t.Fatalf("expected no sessions in version, but got %#v", gotSessions)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not query backup for invalid session",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not update backup with missing version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys")
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not update backup with invalid data",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, "")
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||
"version": "0",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not update backup with wrong version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||
"rooms": map[string]interface{}{
|
||||
"!testroom:test": map[string]interface{}{
|
||||
"sessions": map[string]uapi.KeyBackupSession{},
|
||||
},
|
||||
},
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||
"version": "5",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusForbidden)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can update backup with correct version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||
"rooms": map[string]interface{}{
|
||||
"!testroom:test": map[string]interface{}{
|
||||
"sessions": map[string]uapi.KeyBackupSession{
|
||||
"dummySession": {
|
||||
FirstMessageIndex: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can update backup with correct version for specific room",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||
"sessions": map[string]uapi.KeyBackupSession{
|
||||
"dummySession": {
|
||||
FirstMessageIndex: 1,
|
||||
IsVerified: true,
|
||||
SessionData: json.RawMessage("{}"),
|
||||
},
|
||||
},
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
t.Logf("%#v", rec.Body.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can update backup with correct version for specific room and session",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||
FirstMessageIndex: 1,
|
||||
SessionData: json.RawMessage("{}"),
|
||||
IsVerified: true,
|
||||
ForwardedCount: 0,
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can update backup by version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||
FirstMessageIndex: 1,
|
||||
SessionData: json.RawMessage("{}"),
|
||||
IsVerified: true,
|
||||
ForwardedCount: 0,
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
t.Logf("%#v", rec.Body.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not update backup by version for invalid version",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||
FirstMessageIndex: 1,
|
||||
SessionData: json.RawMessage("{}"),
|
||||
IsVerified: true,
|
||||
ForwardedCount: 0,
|
||||
})
|
||||
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup sessions",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 {
|
||||
t.Fatalf("expected one room in response, but got %#v", rec.Body.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup sessions by room",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 {
|
||||
t.Fatalf("expected one session in response, but got %#v", rec.Body.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can query backup sessions by room and sessionID",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{
|
||||
"version": "1",
|
||||
}))
|
||||
return req
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() {
|
||||
t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not delete invalid version backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can delete version backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deleting the same backup version twice doesn't error",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusOK)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deleting an empty version doesn't work", // make sure we can't delete an empty backup version. Handled at the router level
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
handleResponseCode(t, rec, http.StatusNotFound)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
defer close()
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
}
|
||||
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := tc.request(t)
|
||||
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||
routers.Client.ServeHTTP(rec, req)
|
||||
tc.validate(t, rec)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -171,6 +171,23 @@ func LeaveServerNoticeError() *MatrixError {
|
|||
}
|
||||
}
|
||||
|
||||
// ErrRoomKeysVersion is an error returned by `PUT /room_keys/keys`
|
||||
type ErrRoomKeysVersion struct {
|
||||
MatrixError
|
||||
CurrentVersion string `json:"current_version"`
|
||||
}
|
||||
|
||||
// WrongBackupVersionError is an error returned by `PUT /room_keys/keys`
|
||||
func WrongBackupVersionError(currentVersion string) *ErrRoomKeysVersion {
|
||||
return &ErrRoomKeysVersion{
|
||||
MatrixError: MatrixError{
|
||||
ErrCode: "M_WRONG_ROOM_KEYS_VERSION",
|
||||
Err: "Wrong backup version.",
|
||||
},
|
||||
CurrentVersion: currentVersion,
|
||||
}
|
||||
}
|
||||
|
||||
type IncompatibleRoomVersionError struct {
|
||||
RoomVersion string `json:"room_version"`
|
||||
Error string `json:"error"`
|
||||
|
|
|
@ -61,28 +61,26 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
|||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
if len(kb.AuthData) == 0 {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("missing auth_data"),
|
||||
}
|
||||
}
|
||||
version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: "",
|
||||
AuthData: kb.AuthData,
|
||||
Algorithm: kb.Algorithm,
|
||||
}, &performKeyBackupResp); err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
})
|
||||
if err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", err))
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: keyBackupVersionCreateResponse{
|
||||
Version: performKeyBackupResp.Version,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -90,15 +88,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
|||
// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned.
|
||||
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
||||
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||
var queryResp userapi.QueryKeyBackupResponse
|
||||
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
}, &queryResp); err != nil {
|
||||
return jsonerror.InternalAPIError(req.Context(), err)
|
||||
}
|
||||
if queryResp.Error != "" {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||
})
|
||||
if err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", err))
|
||||
}
|
||||
if !queryResp.Exists {
|
||||
return util.JSONResponse{
|
||||
|
@ -126,31 +121,29 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
|||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
AuthData: kb.AuthData,
|
||||
Algorithm: kb.Algorithm,
|
||||
}, &performKeyBackupResp); err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
})
|
||||
switch e := err.(type) {
|
||||
case *jsonerror.ErrRoomKeysVersion:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: e,
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
case nil:
|
||||
default:
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
|
||||
}
|
||||
|
||||
if !performKeyBackupResp.Exists {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("backup version not found"),
|
||||
}
|
||||
}
|
||||
// Unclear what the 200 body should be
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: keyBackupVersionCreateResponse{
|
||||
|
@ -162,35 +155,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
|||
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
||||
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
||||
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
DeleteBackup: true,
|
||||
}, &performKeyBackupResp); err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err))
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
}
|
||||
if !performKeyBackupResp.Exists {
|
||||
if !exists {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("backup version not found"),
|
||||
}
|
||||
}
|
||||
// Unclear what the 200 body should be
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: keyBackupVersionCreateResponse{
|
||||
Version: performKeyBackupResp.Version,
|
||||
},
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,22 +175,21 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
|||
func UploadBackupKeys(
|
||||
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest,
|
||||
) util.JSONResponse {
|
||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
Keys: *keys,
|
||||
}, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
})
|
||||
|
||||
switch e := err.(type) {
|
||||
case *jsonerror.ErrRoomKeysVersion:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: e,
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
case nil:
|
||||
default:
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
|
||||
}
|
||||
if !performKeyBackupResp.Exists {
|
||||
return util.JSONResponse{
|
||||
|
@ -234,18 +210,15 @@ func UploadBackupKeys(
|
|||
func GetBackupKeys(
|
||||
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
|
||||
) util.JSONResponse {
|
||||
var queryResp userapi.QueryKeyBackupResponse
|
||||
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
ReturnKeys: true,
|
||||
KeysForRoomID: roomID,
|
||||
KeysForSessionID: sessionID,
|
||||
}, &queryResp); err != nil {
|
||||
return jsonerror.InternalAPIError(req.Context(), err)
|
||||
}
|
||||
if queryResp.Error != "" {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||
})
|
||||
if err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %w", err))
|
||||
}
|
||||
if !queryResp.Exists {
|
||||
return util.JSONResponse{
|
||||
|
@ -267,17 +240,20 @@ func GetBackupKeys(
|
|||
}
|
||||
} else if roomID != "" {
|
||||
roomData, ok := queryResp.Keys[roomID]
|
||||
if ok {
|
||||
// wrap response in "sessions"
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{
|
||||
Sessions: roomData,
|
||||
},
|
||||
}
|
||||
if !ok {
|
||||
// If no keys are found, then an object with an empty sessions property will be returned
|
||||
roomData = make(map[string]userapi.KeyBackupSession)
|
||||
}
|
||||
// wrap response in "sessions"
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{
|
||||
Sessions: roomData,
|
||||
},
|
||||
}
|
||||
|
||||
} else {
|
||||
// response is the same as the upload request
|
||||
var resp keyBackupSessionRequest
|
||||
|
|
|
@ -87,6 +87,7 @@ type ClientUserAPI interface {
|
|||
UserLoginAPI
|
||||
ClientKeyAPI
|
||||
ProfileAPI
|
||||
KeyBackupAPI
|
||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
|
@ -105,8 +106,6 @@ type ClientUserAPI interface {
|
|||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
||||
|
||||
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
||||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||
|
@ -114,6 +113,13 @@ type ClientUserAPI interface {
|
|||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
type KeyBackupAPI interface {
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error)
|
||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest) (string, error)
|
||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest) (*QueryKeyBackupResponse, error)
|
||||
UpdateBackupKeyAuthData(ctx context.Context, req *PerformKeyBackupRequest) (*PerformKeyBackupResponse, error)
|
||||
}
|
||||
|
||||
type ProfileAPI interface {
|
||||
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
|
@ -135,11 +141,10 @@ type UserLoginAPI interface {
|
|||
}
|
||||
|
||||
type PerformKeyBackupRequest struct {
|
||||
UserID string
|
||||
Version string // optional if modifying a key backup
|
||||
AuthData json.RawMessage
|
||||
Algorithm string
|
||||
DeleteBackup bool // if true will delete the backup based on 'Version'.
|
||||
UserID string
|
||||
Version string // optional if modifying a key backup
|
||||
AuthData json.RawMessage
|
||||
Algorithm string
|
||||
|
||||
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
||||
Keys struct {
|
||||
|
@ -180,9 +185,6 @@ type InternalKeyBackupSession struct {
|
|||
}
|
||||
|
||||
type PerformKeyBackupResponse struct {
|
||||
Error string // set if there was a problem performing the request
|
||||
BadInput bool // if set, the Error was due to bad input (HTTP 400)
|
||||
|
||||
Exists bool // set to true if the Version exists
|
||||
Version string // the newly created version
|
||||
|
||||
|
@ -200,7 +202,6 @@ type QueryKeyBackupRequest struct {
|
|||
}
|
||||
|
||||
type QueryKeyBackupResponse struct {
|
||||
Error string
|
||||
Exists bool
|
||||
|
||||
Algorithm string `json:"algorithm"`
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -678,62 +679,43 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
|
||||
// Delete metadata
|
||||
if req.DeleteBackup {
|
||||
if req.Version == "" {
|
||||
res.BadInput = true
|
||||
res.Error = "must specify a version to delete"
|
||||
return nil
|
||||
}
|
||||
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
||||
}
|
||||
res.Exists = exists
|
||||
res.Version = req.Version
|
||||
return nil
|
||||
}
|
||||
func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) {
|
||||
return a.DB.DeleteKeyBackup(ctx, userID, version)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest) (string, error) {
|
||||
// Create metadata
|
||||
if req.Version == "" {
|
||||
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to create backup: %s", err)
|
||||
}
|
||||
res.Exists = err == nil
|
||||
res.Version = version
|
||||
return nil
|
||||
}
|
||||
return a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) UpdateBackupKeyAuthData(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
|
||||
res := &api.PerformKeyBackupResponse{}
|
||||
// Update metadata
|
||||
if len(req.Keys.Rooms) == 0 {
|
||||
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
||||
}
|
||||
res.Exists = err == nil
|
||||
res.Version = req.Version
|
||||
return nil
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to update backup: %w", err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
// Upload Keys for a specific version metadata
|
||||
a.uploadBackupKeys(ctx, req, res)
|
||||
return nil
|
||||
return a.uploadBackupKeys(ctx, req)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
|
||||
res := &api.PerformKeyBackupResponse{}
|
||||
// you can only upload keys for the CURRENT version
|
||||
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
||||
return
|
||||
return res, fmt.Errorf("failed to query version: %w", err)
|
||||
}
|
||||
if deleted {
|
||||
res.Error = "backup was deleted"
|
||||
return
|
||||
return res, fmt.Errorf("backup was deleted")
|
||||
}
|
||||
if version != req.Version {
|
||||
res.BadInput = true
|
||||
res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
|
||||
return
|
||||
return res, jsonerror.WrongBackupVersionError(version)
|
||||
}
|
||||
res.Exists = true
|
||||
res.Version = version
|
||||
|
@ -751,23 +733,25 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
|||
}
|
||||
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
|
||||
return
|
||||
return res, fmt.Errorf("failed to upsert keys: %w", err)
|
||||
}
|
||||
res.KeyCount = count
|
||||
res.KeyETag = etag
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
|
||||
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest) (*api.QueryKeyBackupResponse, error) {
|
||||
res := &api.QueryKeyBackupResponse{}
|
||||
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||
res.Version = version
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
res.Exists = false
|
||||
return nil
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return res, nil
|
||||
}
|
||||
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
|
||||
return nil
|
||||
if errors.Is(err, strconv.ErrSyntax) {
|
||||
return res, nil
|
||||
}
|
||||
return res, fmt.Errorf("failed to query key backup: %s", err)
|
||||
}
|
||||
res.Algorithm = algorithm
|
||||
res.AuthData = authData
|
||||
|
@ -777,18 +761,17 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
|||
if !req.ReturnKeys {
|
||||
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
||||
return res, fmt.Errorf("failed to count keys: %w", err)
|
||||
}
|
||||
return nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
||||
return nil
|
||||
return res, fmt.Errorf("failed to query keys: %s", err)
|
||||
}
|
||||
res.Keys = result
|
||||
return nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
|
||||
|
|
Loading…
Reference in a new issue