mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Begin adding tests
This commit is contained in:
parent
c6457cd4e5
commit
baf502f112
|
|
@ -1758,3 +1758,106 @@ func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventC
|
||||||
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||||
return []id.RoomID{}
|
return []id.RoomID{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKeyBackup(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
request func(t *testing.T) *http.Request
|
||||||
|
validate func(t *testing.T, rec *httptest.ResponseRecorder)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "can not create backup with invalid JSON",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`)
|
||||||
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("HTTP[%d]: expected an error, but got none: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can create backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
|
||||||
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("HTTP[%d]: expected no error, but got: %s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
wantVersion := "1"
|
||||||
|
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
||||||
|
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can not delete invalid version backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
if rec.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("HTTP[%d]: expected HTTP 404, but got: %d", rec.Code, rec.Code)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "can delete version backup",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleting the same backup version twice doesn't error",
|
||||||
|
request: func(t *testing.T) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("HTTP[%d]: expected HTTP 200, but got: %d", rec.Code, rec.Code)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
natsInstance := jetstream.NATSInstance{}
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
routers := httputil.NewRouters()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||||
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
||||||
|
|
||||||
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||||
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||||
|
|
||||||
|
accessTokens := map[*test.User]userDevice{
|
||||||
|
alice: {},
|
||||||
|
}
|
||||||
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := tc.request(t)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
||||||
|
routers.Client.ServeHTTP(rec, req)
|
||||||
|
tc.validate(t, rec)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -71,12 +71,6 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
if performKeyBackupResp.Error != "" {
|
if performKeyBackupResp.Error != "" {
|
||||||
if performKeyBackupResp.BadInput {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
@ -162,35 +156,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse
|
||||||
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK.
|
||||||
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
// Implements DELETE /_matrix/client/r0/room_keys/version/{version}
|
||||||
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||||
var performKeyBackupResp userapi.PerformKeyBackupResponse
|
exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version)
|
||||||
if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{
|
if err != nil {
|
||||||
UserID: device.UserID,
|
return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err))
|
||||||
Version: version,
|
|
||||||
DeleteBackup: true,
|
|
||||||
}, &performKeyBackupResp); err != nil {
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
}
|
||||||
if performKeyBackupResp.Error != "" {
|
if !exists {
|
||||||
if performKeyBackupResp.BadInput {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error))
|
|
||||||
}
|
|
||||||
if !performKeyBackupResp.Exists {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 404,
|
Code: 404,
|
||||||
JSON: jsonerror.NotFound("backup version not found"),
|
JSON: jsonerror.NotFound("backup version not found"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Unclear what the 200 body should be
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: keyBackupVersionCreateResponse{
|
JSON: struct{}{},
|
||||||
Version: performKeyBackupResp.Version,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ type ClientUserAPI interface {
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
ClientKeyAPI
|
ClientKeyAPI
|
||||||
ProfileAPI
|
ProfileAPI
|
||||||
|
KeyBackupAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
|
@ -105,8 +106,6 @@ type ClientUserAPI interface {
|
||||||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
|
||||||
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
|
||||||
|
|
||||||
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
|
||||||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||||
|
|
@ -114,6 +113,12 @@ type ClientUserAPI interface {
|
||||||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyBackupAPI interface {
|
||||||
|
DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error)
|
||||||
|
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||||
|
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
type ProfileAPI interface {
|
type ProfileAPI interface {
|
||||||
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
|
|
@ -139,7 +144,6 @@ type PerformKeyBackupRequest struct {
|
||||||
Version string // optional if modifying a key backup
|
Version string // optional if modifying a key backup
|
||||||
AuthData json.RawMessage
|
AuthData json.RawMessage
|
||||||
Algorithm string
|
Algorithm string
|
||||||
DeleteBackup bool // if true will delete the backup based on 'Version'.
|
|
||||||
|
|
||||||
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
// The keys to upload, if any. If blank, creates/updates/deletes key version metadata only.
|
||||||
Keys struct {
|
Keys struct {
|
||||||
|
|
|
||||||
|
|
@ -683,22 +683,11 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) {
|
||||||
|
return a.DB.DeleteKeyBackup(ctx, userID, version)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
|
func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
|
||||||
// Delete metadata
|
|
||||||
if req.DeleteBackup {
|
|
||||||
if req.Version == "" {
|
|
||||||
res.BadInput = true
|
|
||||||
res.Error = "must specify a version to delete"
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
|
||||||
if err != nil {
|
|
||||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
|
||||||
}
|
|
||||||
res.Exists = exists
|
|
||||||
res.Version = req.Version
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Create metadata
|
// Create metadata
|
||||||
if req.Version == "" {
|
if req.Version == "" {
|
||||||
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue