Flush out forwardasync tests

This commit is contained in:
Devon Hudson 2022-12-16 10:34:58 -07:00
parent ed42b252ee
commit bd40d53cbb
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
3 changed files with 151 additions and 77 deletions

View file

@ -36,15 +36,12 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil { assert.Nil(t, err, "Invalid userID")
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := createTransaction() transaction := createTransaction()
_, err = db.StoreAsyncTransaction(context.Background(), transaction) _, err = db.StoreAsyncTransaction(context.Background(), transaction)
if err != nil { assert.Nil(t, err, "Failed to store transaction")
t.Fatalf("Failed to store transaction: %s", err.Error())
}
relayAPI := internal.NewRelayInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "", &db, nil, nil, nil, nil, false, "",
@ -59,7 +56,7 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn)
count, err := db.GetAsyncTransactionCount(context.Background(), *userID) count, err := db.GetAsyncTransactionCount(context.Background(), *userID)
assert.Equal(t, count, int64(0)) assert.Zero(t, count)
} }
func TestGetAsyncReturnsSavedTransaction(t *testing.T) { func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
@ -71,14 +68,12 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil { assert.Nil(t, err, "Invalid userID")
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := createTransaction() transaction := createTransaction()
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
if err != nil { assert.Nil(t, err, "Failed to store transaction")
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations( err = db.AssociateAsyncTransactionWithDestinations(
context.Background(), context.Background(),
map[gomatrixserverlib.UserID]struct{}{ map[gomatrixserverlib.UserID]struct{}{
@ -86,9 +81,7 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
}, },
transaction.TransactionID, transaction.TransactionID,
receipt) receipt)
if err != nil { assert.Nil(t, err, "Failed to associate transaction with user")
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
relayAPI := internal.NewRelayInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "", &db, nil, nil, nil, nil, false, "",
@ -99,7 +92,7 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse) jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, true, jsonResponse.EntriesQueued) assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction, jsonResponse.Txn) assert.Equal(t, transaction, jsonResponse.Txn)
// And once more to clear the queue // And once more to clear the queue
@ -108,11 +101,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse) jsonResponse = response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, false, jsonResponse.EntriesQueued) assert.False(t, jsonResponse.EntriesQueued)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn)
count, err := db.GetAsyncTransactionCount(context.Background(), *userID) count, err := db.GetAsyncTransactionCount(context.Background(), *userID)
assert.Equal(t, count, int64(0)) assert.Zero(t, count)
} }
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
@ -124,15 +117,12 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil { assert.Nil(t, err, "Invalid userID")
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := createTransaction() transaction := createTransaction()
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
if err != nil { assert.Nil(t, err, "Failed to store transaction")
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations( err = db.AssociateAsyncTransactionWithDestinations(
context.Background(), context.Background(),
map[gomatrixserverlib.UserID]struct{}{ map[gomatrixserverlib.UserID]struct{}{
@ -140,15 +130,12 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
}, },
transaction.TransactionID, transaction.TransactionID,
receipt) receipt)
if err != nil { assert.Nil(t, err, "Failed to associate transaction with user")
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
transaction2 := createTransaction() transaction2 := createTransaction()
receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2) receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2)
if err != nil { assert.Nil(t, err, "Failed to store transaction")
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations( err = db.AssociateAsyncTransactionWithDestinations(
context.Background(), context.Background(),
map[gomatrixserverlib.UserID]struct{}{ map[gomatrixserverlib.UserID]struct{}{
@ -156,9 +143,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
}, },
transaction2.TransactionID, transaction2.TransactionID,
receipt2) receipt2)
if err != nil { assert.Nil(t, err, "Failed to associate transaction with user")
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
relayAPI := internal.NewRelayInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "", &db, nil, nil, nil, nil, false, "",
@ -169,7 +154,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse) jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, true, jsonResponse.EntriesQueued) assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction, jsonResponse.Txn) assert.Equal(t, transaction, jsonResponse.Txn)
request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay") request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay")
@ -177,7 +162,7 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse) jsonResponse = response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, true, jsonResponse.EntriesQueued) assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction2, jsonResponse.Txn) assert.Equal(t, transaction2, jsonResponse.Txn)
// And once more to clear the queue // And once more to clear the queue
@ -186,9 +171,9 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse) jsonResponse = response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, false, jsonResponse.EntriesQueued) assert.False(t, jsonResponse.EntriesQueued)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn)
count, err := db.GetAsyncTransactionCount(context.Background(), *userID) count, err := db.GetAsyncTransactionCount(context.Background(), *userID)
assert.Equal(t, count, int64(0)) assert.Zero(t, count)
} }

View file

@ -26,10 +26,10 @@ func ForwardAsync(
} }
if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil {
logrus.Info("The request body could not be decoded into valid JSON. " + err.Error()) logrus.Info("The request body could not be decoded into valid JSON." + err.Error())
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()),
} }
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/relayapi/storage" "github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -32,17 +33,22 @@ func createTransaction() gomatrixserverlib.Transaction {
return txn return txn
} }
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) { func createFederationRequest(
txn := createTransaction() userID gomatrixserverlib.UserID,
txnID gomatrixserverlib.TransactionID,
origin gomatrixserverlib.ServerName,
destination gomatrixserverlib.ServerName,
content interface{},
) gomatrixserverlib.FederationRequest {
var federationPathPrefixV1 = "/_matrix/federation/v1" var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() path := federationPathPrefixV1 + "/forward_async/" + string(txnID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path) request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path)
request.SetContent(txn) request.SetContent(content)
return txn, request return request
} }
func TestEmptyForwardReturnsOk(t *testing.T) { func TestForwardEmptyReturnsOk(t *testing.T) {
testDB := storage.NewFakeRelayDatabase() testDB := storage.NewFakeRelayDatabase()
db := shared.Database{ db := shared.Database{
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
@ -51,10 +57,10 @@ func TestEmptyForwardReturnsOk(t *testing.T) {
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil { assert.Nil(t, err, "Invalid userID")
t.Fatalf("Invalid userID: %s", err.Error())
} txn := createTransaction()
_, request := createFederationRequest(*userID) request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
relayAPI := internal.NewRelayInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "", &db, nil, nil, nil, nil, false, "",
@ -62,10 +68,104 @@ func TestEmptyForwardReturnsOk(t *testing.T) {
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
expected := 200 assert.Equal(t, 200, response.Code)
if response.Code != expected { }
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
func TestForwardBadJSONReturnsError(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
} }
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
type BadData struct {
Field bool `json:"pdus"`
}
content := BadData{
Field: false,
}
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
}
func TestForwardTooManyPDUsReturnsError(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
type BadData struct {
Field []json.RawMessage `json:"pdus"`
}
content := BadData{
Field: []json.RawMessage{},
}
for i := 0; i < 51; i++ {
content.Field = append(content.Field, []byte{})
}
assert.Greater(t, len(content.Field), 50)
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
}
func TestForwardTooManyEDUsReturnsError(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
type BadData struct {
Field []gomatrixserverlib.EDU `json:"edus"`
}
content := BadData{
Field: []gomatrixserverlib.EDU{},
}
for i := 0; i < 101; i++ {
content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping})
}
assert.Greater(t, len(content.Field), 100)
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
} }
func TestUniqueTransactionStoredInDatabase(t *testing.T) { func TestUniqueTransactionStoredInDatabase(t *testing.T) {
@ -77,35 +177,24 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) {
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil { assert.Nil(t, err, "Invalid userID")
t.Fatalf("Invalid userID: %s", err.Error())
} txn := createTransaction()
inputTransaction, request := createFederationRequest(*userID) request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
relayAPI := internal.NewRelayInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "", &db, nil, nil, nil, nil, false, "",
) )
response := routing.ForwardAsync( response := routing.ForwardAsync(
httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID) httpReq, &request, &relayAPI, txn.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID) transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil { assert.Nil(t, err, "Failed retrieving transaction")
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200 transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if response.Code != expected { assert.Nil(t, err, "Failed retrieving transaction count")
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
} assert.Equal(t, 200, response.Code)
if transactionCount != 1 { assert.Equal(t, int64(1), transactionCount)
t.Fatalf("Expected count of 1, Actual: %d", transactionCount) assert.Equal(t, txn.TransactionID, transaction.TransactionID)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
} }