From b9d5fd942fb33a3630a6113fa980a6a422f224c4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 10 Nov 2022 18:24:21 -0700 Subject: [PATCH] Add initial test for forward_async federation endpoint --- federationapi/routing/forwardasync.go | 24 ++++++++ federationapi/routing/forwardasync_test.go | 64 ++++++++++++++++++++++ federationapi/storage/interface.go | 4 ++ federationapi/storage/shared/storage.go | 45 +++++++++++++++ go.mod | 2 + 5 files changed, 139 insertions(+) create mode 100644 federationapi/routing/forwardasync.go create mode 100644 federationapi/routing/forwardasync_test.go diff --git a/federationapi/routing/forwardasync.go b/federationapi/routing/forwardasync.go new file mode 100644 index 000000000..e9a46d35e --- /dev/null +++ b/federationapi/routing/forwardasync.go @@ -0,0 +1,24 @@ +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// ForwardAsync implements /_matrix/federation/v1/forward_async/{txnID}/{userID} +func ForwardAsync( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + fedAPI api.FederationInternalAPI, + txnId gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + + // TODO: wrap in fedAPI call + // fedAPI.db.AssociateAsyncTransactionWithDestinations(context.TODO(), userID, nil) + + return util.JSONResponse{Code: 200} +} \ No newline at end of file diff --git a/federationapi/routing/forwardasync_test.go b/federationapi/routing/forwardasync_test.go new file mode 100644 index 000000000..4c06a86d4 --- /dev/null +++ b/federationapi/routing/forwardasync_test.go @@ -0,0 +1,64 @@ +package routing_test + +import ( + // "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + // "github.com/matrix-org/dendrite/federationapi/storage/shared" + // "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +func TestEmptyForwardReturnsOk(t *testing.T) { + httpReq := &http.Request{} + request := &gomatrixserverlib.FederationRequest{} + fedAPI := internal.FederationInternalAPI{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + + response := routing.ForwardAsync(httpReq, request, &fedAPI, "1", *userID) + + expected := 200 + if response.Code != expected { + t.Fatalf("Expected: %v, Actual: %v", expected, response.Code) + } +} + +// func TestUniqueTransactionStoredInDatabase(t *testing.T) { +// db := shared.Database{} +// httpReq := &http.Request{} +// inputTransaction := gomatrixserverlib.Transaction{} +// request := &gomatrixserverlib.FederationRequest{} +// fedAPI := internal.NewFederationInternalAPI( +// &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, +// ) +// userID, err := gomatrixserverlib.NewUserID("@local:domain", false) +// if err != nil { +// t.Fatalf("Invalid userID: %s", err.Error()) +// } + +// response := routing.ForwardAsync(httpReq, request, fedAPI, "1", *userID) +// transaction, err := db.GetAsyncTransaction(context.TODO(), *userID) +// transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) + +// expected := 200 +// if response.Code != expected { +// t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) +// } +// if transactionCount != 1 { +// t.Fatalf("Expected count of 1, Actual: %d", transactionCount) +// } +// if transaction.TransactionID != inputTransaction.TransactionID { +// t.Fatalf("Expected Transaction ID: %s, Actual: %s", +// inputTransaction.TransactionID, transaction.TransactionID) +// } +// } + +// func TestDuplicateTransactionNotStoredInDatabase(t *testing.T) { + +// } \ No newline at end of file diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 09098cd1e..4540a65fe 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -51,6 +51,10 @@ type Database interface { GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, error) + GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) + AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, receipt *shared.Receipt) error + // these don't have contexts passed in as we want things to happen regardless of the request context AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 4fabff7d4..ab5d66ac8 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -27,6 +27,11 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type transactionEntry struct { + transaction gomatrixserverlib.Transaction + userID []gomatrixserverlib.UserID +} + type Database struct { DB *sql.DB IsLocalServerName func(gomatrixserverlib.ServerName) bool @@ -42,6 +47,7 @@ type Database struct { NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata ServerSigningKeys tables.FederationServerSigningKeys + transactionDB map[Receipt]transactionEntry } // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. @@ -257,3 +263,42 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) AssociateAsyncTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + receipt *Receipt, +) error { + if transaction, ok := d.transactionDB[*receipt]; ok { + for k := range destinations { + transaction.userID = append(transaction.userID, k) + } + d.transactionDB[*receipt] = transaction + } else { + return fmt.Errorf("No transactions exist with that NID") + } + + return nil +} + +func (d *Database) GetAsyncTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (gomatrixserverlib.Transaction, error) { + return gomatrixserverlib.Transaction{}, nil +} + +func (d *Database) GetAsyncTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count := int64(0) + for _, transaction := range d.transactionDB { + for _, user := range transaction.userID { + if user == userID { + count++ + } + } + } + return count, nil +} diff --git a/go.mod b/go.mod index 7dd9e0b2c..0b40b8d1f 100644 --- a/go.mod +++ b/go.mod @@ -142,3 +142,5 @@ require ( ) go 1.18 + +replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/mailbox