diff --git a/relayapi/routing/asyncevents_test.go b/relayapi/routing/asyncevents_test.go index 7d65a8085..cf9397fff 100644 --- a/relayapi/routing/asyncevents_test.go +++ b/relayapi/routing/asyncevents_test.go @@ -36,15 +36,12 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } + assert.Nil(t, err, "Invalid userID") + transaction := createTransaction() _, err = db.StoreAsyncTransaction(context.Background(), transaction) - if err != nil { - t.Fatalf("Failed to store transaction: %s", err.Error()) - } + assert.Nil(t, err, "Failed to store transaction") relayAPI := internal.NewRelayInternalAPI( &db, nil, nil, nil, nil, false, "", @@ -59,7 +56,7 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) count, err := db.GetAsyncTransactionCount(context.Background(), *userID) - assert.Equal(t, count, int64(0)) + assert.Zero(t, count) } func TestGetAsyncReturnsSavedTransaction(t *testing.T) { @@ -71,14 +68,12 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } + assert.Nil(t, err, "Invalid userID") + transaction := createTransaction() receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) - if err != nil { - t.Fatalf("Failed to store transaction: %s", err.Error()) - } + assert.Nil(t, err, "Failed to store transaction") + err = db.AssociateAsyncTransactionWithDestinations( context.Background(), map[gomatrixserverlib.UserID]struct{}{ @@ -86,9 +81,7 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { }, transaction.TransactionID, receipt) - if err != nil { - t.Fatalf("Failed to associate transaction with user: %s", err.Error()) - } + assert.Nil(t, err, "Failed to associate transaction with user") relayAPI := internal.NewRelayInternalAPI( &db, nil, nil, nil, nil, false, "", @@ -99,7 +92,7 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Txn) // And once more to clear the queue @@ -108,11 +101,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) jsonResponse = response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) count, err := db.GetAsyncTransactionCount(context.Background(), *userID) - assert.Equal(t, count, int64(0)) + assert.Zero(t, count) } func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { @@ -124,15 +117,12 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } + assert.Nil(t, err, "Invalid userID") transaction := createTransaction() receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) - if err != nil { - t.Fatalf("Failed to store transaction: %s", err.Error()) - } + assert.Nil(t, err, "Failed to store transaction") + err = db.AssociateAsyncTransactionWithDestinations( context.Background(), map[gomatrixserverlib.UserID]struct{}{ @@ -140,15 +130,12 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { }, transaction.TransactionID, receipt) - if err != nil { - t.Fatalf("Failed to associate transaction with user: %s", err.Error()) - } + assert.Nil(t, err, "Failed to associate transaction with user") transaction2 := createTransaction() receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2) - if err != nil { - t.Fatalf("Failed to store transaction: %s", err.Error()) - } + assert.Nil(t, err, "Failed to store transaction") + err = db.AssociateAsyncTransactionWithDestinations( context.Background(), map[gomatrixserverlib.UserID]struct{}{ @@ -156,9 +143,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { }, transaction2.TransactionID, receipt2) - if err != nil { - t.Fatalf("Failed to associate transaction with user: %s", err.Error()) - } + assert.Nil(t, err, "Failed to associate transaction with user") relayAPI := internal.NewRelayInternalAPI( &db, nil, nil, nil, nil, false, "", @@ -169,7 +154,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Txn) request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay") @@ -177,7 +162,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) jsonResponse = response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction2, jsonResponse.Txn) // And once more to clear the queue @@ -186,9 +171,9 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { assert.Equal(t, http.StatusOK, response.Code) jsonResponse = response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) count, err := db.GetAsyncTransactionCount(context.Background(), *userID) - assert.Equal(t, count, int64(0)) + assert.Zero(t, count) } diff --git a/relayapi/routing/forwardasync.go b/relayapi/routing/forwardasync.go index ac9390519..4d499486d 100644 --- a/relayapi/routing/forwardasync.go +++ b/relayapi/routing/forwardasync.go @@ -26,10 +26,10 @@ func ForwardAsync( } if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { - logrus.Info("The request body could not be decoded into valid JSON. " + err.Error()) + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), } } diff --git a/relayapi/routing/forwardasync_test.go b/relayapi/routing/forwardasync_test.go index f402e8a47..c68e119e0 100644 --- a/relayapi/routing/forwardasync_test.go +++ b/relayapi/routing/forwardasync_test.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/relayapi/storage" "github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" ) const ( @@ -32,17 +33,22 @@ func createTransaction() gomatrixserverlib.Transaction { return txn } -func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) { - txn := createTransaction() +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { var federationPathPrefixV1 = "/_matrix/federation/v1" - path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() - request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path) - request.SetContent(txn) + path := federationPathPrefixV1 + "/forward_async/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) - return txn, request + return request } -func TestEmptyForwardReturnsOk(t *testing.T) { +func TestForwardEmptyReturnsOk(t *testing.T) { testDB := storage.NewFakeRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), @@ -51,10 +57,10 @@ func TestEmptyForwardReturnsOk(t *testing.T) { } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - _, request := createFederationRequest(*userID) + assert.Nil(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) relayAPI := internal.NewRelayInternalAPI( &db, nil, nil, nil, nil, false, "", @@ -62,10 +68,104 @@ func TestEmptyForwardReturnsOk(t *testing.T) { response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) - expected := 200 - if response.Code != expected { - t.Fatalf("Expected: %v, Actual: %v", expected, response.Code) + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := storage.NewFakeRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := storage.NewFakeRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := storage.NewFakeRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) } func TestUniqueTransactionStoredInDatabase(t *testing.T) { @@ -77,35 +177,24 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) { } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - inputTransaction, request := createFederationRequest(*userID) + assert.Nil(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) relayAPI := internal.NewRelayInternalAPI( &db, nil, nil, nil, nil, false, "", ) response := routing.ForwardAsync( - httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID) + httpReq, &request, &relayAPI, txn.TransactionID, *userID) transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID) - if err != nil { - t.Fatalf("Failed retrieving transaction: %s", err.Error()) - } - transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) - if err != nil { - t.Fatalf("Failed retrieving transaction count: %s", err.Error()) - } + assert.Nil(t, err, "Failed retrieving transaction") - expected := 200 - if response.Code != expected { - t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) - } - if transactionCount != 1 { - t.Fatalf("Expected count of 1, Actual: %d", transactionCount) - } - if transaction.TransactionID != inputTransaction.TransactionID { - t.Fatalf("Expected Transaction ID: %s, Actual: %s", - inputTransaction.TransactionID, transaction.TransactionID) - } + transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) + assert.Nil(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) }