mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 10:33:11 -06:00
Begin adding tests
This commit is contained in:
parent
c6457cd4e5
commit
baf502f112
|
|
@ -1758,3 +1758,106 @@ func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventC
|
|||
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||
return []id.RoomID{}
|
||||
}
|
||||
|
||||
func TestKeyBackup(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request func(t *testing.T) *http.Request
|
||||
validate func(t *testing.T, rec *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "can not create backup with invalid JSON",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`)
|
||||
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("HTTP[%d]: expected an error, but got none: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can create backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
|
||||
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("HTTP[%d]: expected no error, but got: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
wantVersion := "1"
|
||||
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can not delete invalid version backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("HTTP[%d]: expected HTTP 404, but got: %d", rec.Code, rec.Code)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "can delete version backup",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deleting the same backup version twice doesn't error",
|
||||
request: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||
},
|
||||
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
defer close()
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
}
|
||||
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := tc.request(t)
|
||||
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||
routers.Client.ServeHTTP(rec, req)
|
||||
tc.validate(t, rec)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,12 +71,6 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
|||
return jsonerror.InternalServerError()
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
}
|
||||
return util.JSONResponse{
|
||||
|
|
@ -162,35 +156,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
|||
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
||||
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
||||
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
DeleteBackup: true,
|
||||
}, &performKeyBackupResp); err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err))
|
||||
}
|
||||
if performKeyBackupResp.Error != "" {
|
||||
if performKeyBackupResp.BadInput {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
||||
}
|
||||
}
|
||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||
}
|
||||
if !performKeyBackupResp.Exists {
|
||||
if !exists {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("backup version not found"),
|
||||
}
|
||||
}
|
||||
// Unclear what the 200 body should be
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: keyBackupVersionCreateResponse{
|
||||
Version: performKeyBackupResp.Version,
|
||||
},
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ type ClientUserAPI interface {
|
|||
UserLoginAPI
|
||||
ClientKeyAPI
|
||||
ProfileAPI
|
||||
KeyBackupAPI
|
||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
|
|
@ -105,8 +106,6 @@ type ClientUserAPI interface {
|
|||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
||||
|
||||
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
||||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||
|
|
@ -114,6 +113,12 @@ type ClientUserAPI interface {
|
|||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
type KeyBackupAPI interface {
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error)
|
||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
||||
}
|
||||
|
||||
type ProfileAPI interface {
|
||||
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
|
|
@ -135,11 +140,10 @@ type UserLoginAPI interface {
|
|||
}
|
||||
|
||||
type PerformKeyBackupRequest struct {
|
||||
UserID string
|
||||
Version string // optional if modifying a key backup
|
||||
AuthData json.RawMessage
|
||||
Algorithm string
|
||||
DeleteBackup bool // if true will delete the backup based on 'Version'.
|
||||
UserID string
|
||||
Version string // optional if modifying a key backup
|
||||
AuthData json.RawMessage
|
||||
Algorithm string
|
||||
|
||||
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
||||
Keys struct {
|
||||
|
|
|
|||
|
|
@ -683,22 +683,11 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) {
|
||||
return a.DB.DeleteKeyBackup(ctx, userID, version)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
|
||||
// Delete metadata
|
||||
if req.DeleteBackup {
|
||||
if req.Version == "" {
|
||||
res.BadInput = true
|
||||
res.Error = "must specify a version to delete"
|
||||
return nil
|
||||
}
|
||||
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
||||
}
|
||||
res.Exists = exists
|
||||
res.Version = req.Version
|
||||
return nil
|
||||
}
|
||||
// Create metadata
|
||||
if req.Version == "" {
|
||||
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||
|
|
|
|||
Loading…
Reference in a new issue