Add key backup tests (#3071)
Also slightly refactors the functions and methods to rely less on the req/res pattern we had for polylith. Returns `M_WRONG_ROOM_KEYS_VERSION` for some endpoints as per the spec
This commit is contained in:
parent
6b47cf0f6a
commit
9e9617ff84
|
@ -1758,3 +1758,377 @@ func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventC
|
||||||
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||||
return []id.RoomID{}
|
return []id.RoomID{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKeyBackup(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) {
|
||||||
|
t.Helper()
|
||||||
|
if rec.Code != expectedCode {
|
||||||
|
t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
request func(t *testing.T) *http.Request
|
||||||
|
validate func(t *testing.T, rec *httptest.ResponseRecorder)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "can not create backup with invalid JSON",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces
|
||||||
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
|
||||||
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can create backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`)
|
||||||
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
wantVersion := "1"
|
||||||
|
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||||
|
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not query backup for invalid version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not query backup for invalid version string",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
wantVersion := "1"
|
||||||
|
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||||
|
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup without returning rooms",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 {
|
||||||
|
t.Fatalf("expected no rooms in version, but got %#v", gotRooms)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup for invalid room",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 {
|
||||||
|
t.Fatalf("expected no sessions in version, but got %#v", gotSessions)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not query backup for invalid session",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not update backup with missing version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys")
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not update backup with invalid data",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, "")
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||||
|
"version": "0",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not update backup with wrong version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"rooms": map[string]interface{}{
|
||||||
|
"!testroom:test": map[string]interface{}{
|
||||||
|
"sessions": map[string]uapi.KeyBackupSession{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||||
|
"version": "5",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusForbidden)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can update backup with correct version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"rooms": map[string]interface{}{
|
||||||
|
"!testroom:test": map[string]interface{}{
|
||||||
|
"sessions": map[string]uapi.KeyBackupSession{
|
||||||
|
"dummySession": {
|
||||||
|
FirstMessageIndex: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can update backup with correct version for specific room",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"sessions": map[string]uapi.KeyBackupSession{
|
||||||
|
"dummySession": {
|
||||||
|
FirstMessageIndex: 1,
|
||||||
|
IsVerified: true,
|
||||||
|
SessionData: json.RawMessage("{}"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
t.Logf("%#v", rec.Body.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can update backup with correct version for specific room and session",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||||
|
FirstMessageIndex: 1,
|
||||||
|
SessionData: json.RawMessage("{}"),
|
||||||
|
IsVerified: true,
|
||||||
|
ForwardedCount: 0,
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can update backup by version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||||
|
FirstMessageIndex: 1,
|
||||||
|
SessionData: json.RawMessage("{}"),
|
||||||
|
IsVerified: true,
|
||||||
|
ForwardedCount: 0,
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
t.Logf("%#v", rec.Body.String())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not update backup by version for invalid version",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
||||||
|
FirstMessageIndex: 1,
|
||||||
|
SessionData: json.RawMessage("{}"),
|
||||||
|
IsVerified: true,
|
||||||
|
ForwardedCount: 0,
|
||||||
|
})
|
||||||
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup sessions",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 {
|
||||||
|
t.Fatalf("expected one room in response, but got %#v", rec.Body.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup sessions by room",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 {
|
||||||
|
t.Fatalf("expected one session in response, but got %#v", rec.Body.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can query backup sessions by room and sessionID",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{
|
||||||
|
"version": "1",
|
||||||
|
}))
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() {
|
||||||
|
t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not delete invalid version backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can delete version backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleting the same backup version twice doesn't error",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusOK)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleting an empty version doesn't work", // make sure we can't delete an empty backup version. Handled at the router level
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
handleResponseCode(t, rec, http.StatusNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
|
||||||
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
alice: {},
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := tc.request(t)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
tc.validate(t, rec)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -171,6 +171,23 @@ func LeaveServerNoticeError() *MatrixError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrRoomKeysVersion is an error returned by `PUT /room_keys/keys`
|
||||||
|
type ErrRoomKeysVersion struct {
|
||||||
|
MatrixError
|
||||||
|
CurrentVersion string `json:"current_version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrongBackupVersionError is an error returned by `PUT /room_keys/keys`
|
||||||
|
func WrongBackupVersionError(currentVersion string) *ErrRoomKeysVersion {
|
||||||
|
return &ErrRoomKeysVersion{
|
||||||
|
MatrixError: MatrixError{
|
||||||
|
ErrCode: "M_WRONG_ROOM_KEYS_VERSION",
|
||||||
|
Err: "Wrong backup version.",
|
||||||
|
},
|
||||||
|
CurrentVersion: currentVersion,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type IncompatibleRoomVersionError struct {
|
type IncompatibleRoomVersionError struct {
|
||||||
RoomVersion string `json:"room_version"`
|
RoomVersion string `json:"room_version"`
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
|
|
|
@ -61,28 +61,26 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
if len(kb.AuthData) == 0 {
|
||||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("missing auth_data"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
Version: "",
|
Version: "",
|
||||||
AuthData: kb.AuthData,
|
AuthData: kb.AuthData,
|
||||||
Algorithm: kb.Algorithm,
|
Algorithm: kb.Algorithm,
|
||||||
}, &performKeyBackupResp); err != nil {
|
})
|
||||||
return jsonerror.InternalServerError()
|
if err != nil {
|
||||||
}
|
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", err))
|
||||||
if performKeyBackupResp.Error != "" {
|
|
||||||
if performKeyBackupResp.BadInput {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: keyBackupVersionCreateResponse{
|
JSON: keyBackupVersionCreateResponse{
|
||||||
Version: performKeyBackupResp.Version,
|
Version: version,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -90,15 +88,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
||||||
// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned.
|
// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned.
|
||||||
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
||||||
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||||
var queryResp userapi.QueryKeyBackupResponse
|
queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||||
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
Version: version,
|
Version: version,
|
||||||
}, &queryResp); err != nil {
|
})
|
||||||
return jsonerror.InternalAPIError(req.Context(), err)
|
if err != nil {
|
||||||
}
|
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", err))
|
||||||
if queryResp.Error != "" {
|
|
||||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
|
||||||
}
|
}
|
||||||
if !queryResp.Exists {
|
if !queryResp.Exists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -126,31 +121,29 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
Version: version,
|
Version: version,
|
||||||
AuthData: kb.AuthData,
|
AuthData: kb.AuthData,
|
||||||
Algorithm: kb.Algorithm,
|
Algorithm: kb.Algorithm,
|
||||||
}, &performKeyBackupResp); err != nil {
|
})
|
||||||
return jsonerror.InternalServerError()
|
switch e := err.(type) {
|
||||||
}
|
case *jsonerror.ErrRoomKeysVersion:
|
||||||
if performKeyBackupResp.Error != "" {
|
return util.JSONResponse{
|
||||||
if performKeyBackupResp.BadInput {
|
Code: http.StatusForbidden,
|
||||||
return util.JSONResponse{
|
JSON: e,
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
case nil:
|
||||||
|
default:
|
||||||
|
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !performKeyBackupResp.Exists {
|
if !performKeyBackupResp.Exists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 404,
|
Code: 404,
|
||||||
JSON: jsonerror.NotFound("backup version not found"),
|
JSON: jsonerror.NotFound("backup version not found"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Unclear what the 200 body should be
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: keyBackupVersionCreateResponse{
|
JSON: keyBackupVersionCreateResponse{
|
||||||
|
@ -162,35 +155,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
||||||
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
||||||
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
||||||
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version)
|
||||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
if err != nil {
|
||||||
UserID: device.UserID,
|
return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err))
|
||||||
Version: version,
|
|
||||||
DeleteBackup: true,
|
|
||||||
}, &performKeyBackupResp); err != nil {
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
}
|
||||||
if performKeyBackupResp.Error != "" {
|
if !exists {
|
||||||
if performKeyBackupResp.BadInput {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
|
||||||
}
|
|
||||||
if !performKeyBackupResp.Exists {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 404,
|
Code: 404,
|
||||||
JSON: jsonerror.NotFound("backup version not found"),
|
JSON: jsonerror.NotFound("backup version not found"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Unclear what the 200 body should be
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: keyBackupVersionCreateResponse{
|
JSON: struct{}{},
|
||||||
Version: performKeyBackupResp.Version,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,22 +175,21 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
||||||
func UploadBackupKeys(
|
func UploadBackupKeys(
|
||||||
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest,
|
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
Version: version,
|
Version: version,
|
||||||
Keys: *keys,
|
Keys: *keys,
|
||||||
}, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" {
|
})
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
switch e := err.(type) {
|
||||||
if performKeyBackupResp.Error != "" {
|
case *jsonerror.ErrRoomKeysVersion:
|
||||||
if performKeyBackupResp.BadInput {
|
return util.JSONResponse{
|
||||||
return util.JSONResponse{
|
Code: http.StatusForbidden,
|
||||||
Code: 400,
|
JSON: e,
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
case nil:
|
||||||
|
default:
|
||||||
|
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e))
|
||||||
}
|
}
|
||||||
if !performKeyBackupResp.Exists {
|
if !performKeyBackupResp.Exists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -234,18 +210,15 @@ func UploadBackupKeys(
|
||||||
func GetBackupKeys(
|
func GetBackupKeys(
|
||||||
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
|
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var queryResp userapi.QueryKeyBackupResponse
|
queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||||
if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
Version: version,
|
Version: version,
|
||||||
ReturnKeys: true,
|
ReturnKeys: true,
|
||||||
KeysForRoomID: roomID,
|
KeysForRoomID: roomID,
|
||||||
KeysForSessionID: sessionID,
|
KeysForSessionID: sessionID,
|
||||||
}, &queryResp); err != nil {
|
})
|
||||||
return jsonerror.InternalAPIError(req.Context(), err)
|
if err != nil {
|
||||||
}
|
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %w", err))
|
||||||
if queryResp.Error != "" {
|
|
||||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
|
||||||
}
|
}
|
||||||
if !queryResp.Exists {
|
if !queryResp.Exists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -267,17 +240,20 @@ func GetBackupKeys(
|
||||||
}
|
}
|
||||||
} else if roomID != "" {
|
} else if roomID != "" {
|
||||||
roomData, ok := queryResp.Keys[roomID]
|
roomData, ok := queryResp.Keys[roomID]
|
||||||
if ok {
|
if !ok {
|
||||||
// wrap response in "sessions"
|
// If no keys are found, then an object with an empty sessions property will be returned
|
||||||
return util.JSONResponse{
|
roomData = make(map[string]userapi.KeyBackupSession)
|
||||||
Code: 200,
|
|
||||||
JSON: struct {
|
|
||||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
|
||||||
}{
|
|
||||||
Sessions: roomData,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// wrap response in "sessions"
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
}{
|
||||||
|
Sessions: roomData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// response is the same as the upload request
|
// response is the same as the upload request
|
||||||
var resp keyBackupSessionRequest
|
var resp keyBackupSessionRequest
|
||||||
|
|
|
@ -87,6 +87,7 @@ type ClientUserAPI interface {
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
ClientKeyAPI
|
ClientKeyAPI
|
||||||
ProfileAPI
|
ProfileAPI
|
||||||
|
KeyBackupAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
@ -105,8 +106,6 @@ type ClientUserAPI interface {
|
||||||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
|
||||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
|
||||||
|
|
||||||
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
||||||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||||
|
@ -114,6 +113,13 @@ type ClientUserAPI interface {
|
||||||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyBackupAPI interface {
|
||||||
|
DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error)
|
||||||
|
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest) (string, error)
|
||||||
|
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest) (*QueryKeyBackupResponse, error)
|
||||||
|
UpdateBackupKeyAuthData(ctx context.Context, req *PerformKeyBackupRequest) (*PerformKeyBackupResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ProfileAPI interface {
|
type ProfileAPI interface {
|
||||||
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
|
@ -135,11 +141,10 @@ type UserLoginAPI interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformKeyBackupRequest struct {
|
type PerformKeyBackupRequest struct {
|
||||||
UserID string
|
UserID string
|
||||||
Version string // optional if modifying a key backup
|
Version string // optional if modifying a key backup
|
||||||
AuthData json.RawMessage
|
AuthData json.RawMessage
|
||||||
Algorithm string
|
Algorithm string
|
||||||
DeleteBackup bool // if true will delete the backup based on 'Version'.
|
|
||||||
|
|
||||||
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
||||||
Keys struct {
|
Keys struct {
|
||||||
|
@ -180,9 +185,6 @@ type InternalKeyBackupSession struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformKeyBackupResponse struct {
|
type PerformKeyBackupResponse struct {
|
||||||
Error string // set if there was a problem performing the request
|
|
||||||
BadInput bool // if set, the Error was due to bad input (HTTP 400)
|
|
||||||
|
|
||||||
Exists bool // set to true if the Version exists
|
Exists bool // set to true if the Version exists
|
||||||
Version string // the newly created version
|
Version string // the newly created version
|
||||||
|
|
||||||
|
@ -200,7 +202,6 @@ type QueryKeyBackupRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryKeyBackupResponse struct {
|
type QueryKeyBackupResponse struct {
|
||||||
Error string
|
|
||||||
Exists bool
|
Exists bool
|
||||||
|
|
||||||
Algorithm string `json:"algorithm"`
|
Algorithm string `json:"algorithm"`
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -678,62 +679,43 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
|
func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) {
|
||||||
// Delete metadata
|
return a.DB.DeleteKeyBackup(ctx, userID, version)
|
||||||
if req.DeleteBackup {
|
}
|
||||||
if req.Version == "" {
|
|
||||||
res.BadInput = true
|
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest) (string, error) {
|
||||||
res.Error = "must specify a version to delete"
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
|
||||||
if err != nil {
|
|
||||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
|
||||||
}
|
|
||||||
res.Exists = exists
|
|
||||||
res.Version = req.Version
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Create metadata
|
// Create metadata
|
||||||
if req.Version == "" {
|
return a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||||
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
}
|
||||||
if err != nil {
|
|
||||||
res.Error = fmt.Sprintf("failed to create backup: %s", err)
|
func (a *UserInternalAPI) UpdateBackupKeyAuthData(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
|
||||||
}
|
res := &api.PerformKeyBackupResponse{}
|
||||||
res.Exists = err == nil
|
|
||||||
res.Version = version
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Update metadata
|
// Update metadata
|
||||||
if len(req.Keys.Rooms) == 0 {
|
if len(req.Keys.Rooms) == 0 {
|
||||||
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
||||||
if err != nil {
|
|
||||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
|
||||||
}
|
|
||||||
res.Exists = err == nil
|
res.Exists = err == nil
|
||||||
res.Version = req.Version
|
res.Version = req.Version
|
||||||
return nil
|
if err != nil {
|
||||||
|
return res, fmt.Errorf("failed to update backup: %w", err)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
// Upload Keys for a specific version metadata
|
// Upload Keys for a specific version metadata
|
||||||
a.uploadBackupKeys(ctx, req, res)
|
return a.uploadBackupKeys(ctx, req)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
|
||||||
|
res := &api.PerformKeyBackupResponse{}
|
||||||
// you can only upload keys for the CURRENT version
|
// you can only upload keys for the CURRENT version
|
||||||
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
|
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
return res, fmt.Errorf("failed to query version: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if deleted {
|
if deleted {
|
||||||
res.Error = "backup was deleted"
|
return res, fmt.Errorf("backup was deleted")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if version != req.Version {
|
if version != req.Version {
|
||||||
res.BadInput = true
|
return res, jsonerror.WrongBackupVersionError(version)
|
||||||
res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
res.Exists = true
|
res.Exists = true
|
||||||
res.Version = version
|
res.Version = version
|
||||||
|
@ -751,23 +733,25 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
|
return res, fmt.Errorf("failed to upsert keys: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
res.KeyCount = count
|
res.KeyCount = count
|
||||||
res.KeyETag = etag
|
res.KeyETag = etag
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
|
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest) (*api.QueryKeyBackupResponse, error) {
|
||||||
|
res := &api.QueryKeyBackupResponse{}
|
||||||
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
|
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||||
res.Version = version
|
res.Version = version
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
res.Exists = false
|
return res, nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
|
if errors.Is(err, strconv.ErrSyntax) {
|
||||||
return nil
|
return res, nil
|
||||||
|
}
|
||||||
|
return res, fmt.Errorf("failed to query key backup: %s", err)
|
||||||
}
|
}
|
||||||
res.Algorithm = algorithm
|
res.Algorithm = algorithm
|
||||||
res.AuthData = authData
|
res.AuthData = authData
|
||||||
|
@ -777,18 +761,17 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
||||||
if !req.ReturnKeys {
|
if !req.ReturnKeys {
|
||||||
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
|
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
return res, fmt.Errorf("failed to count keys: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
return res, fmt.Errorf("failed to query keys: %s", err)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
res.Keys = result
|
res.Keys = result
|
||||||
return nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
|
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
|
||||||
|
|
Loading…
Reference in a new issue