Add initial test for forward_async federation endpoint

This commit is contained in:
Devon Hudson 2022-11-10 18:24:21 -07:00
parent efe28db631
commit b9d5fd942f
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
5 changed files with 139 additions and 0 deletions

View file

@ -0,0 +1,24 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// ForwardAsync implements /_matrix/federation/v1/forward_async/{txnID}/{userID}
func ForwardAsync(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
txnId gomatrixserverlib.TransactionID,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
// TODO: wrap in fedAPI call
// fedAPI.db.AssociateAsyncTransactionWithDestinations(context.TODO(), userID, nil)
return util.JSONResponse{Code: 200}
}

View file

@ -0,0 +1,64 @@
package routing_test
import (
// "context"
"net/http"
"testing"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
// "github.com/matrix-org/dendrite/federationapi/storage/shared"
// "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
func TestEmptyForwardReturnsOk(t *testing.T) {
httpReq := &http.Request{}
request := &gomatrixserverlib.FederationRequest{}
fedAPI := internal.FederationInternalAPI{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
response := routing.ForwardAsync(httpReq, request, &fedAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
// func TestUniqueTransactionStoredInDatabase(t *testing.T) {
// db := shared.Database{}
// httpReq := &http.Request{}
// inputTransaction := gomatrixserverlib.Transaction{}
// request := &gomatrixserverlib.FederationRequest{}
// fedAPI := internal.NewFederationInternalAPI(
// &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
// )
// userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
// if err != nil {
// t.Fatalf("Invalid userID: %s", err.Error())
// }
// response := routing.ForwardAsync(httpReq, request, fedAPI, "1", *userID)
// transaction, err := db.GetAsyncTransaction(context.TODO(), *userID)
// transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
// expected := 200
// if response.Code != expected {
// t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
// }
// if transactionCount != 1 {
// t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
// }
// if transaction.TransactionID != inputTransaction.TransactionID {
// t.Fatalf("Expected Transaction ID: %s, Actual: %s",
// inputTransaction.TransactionID, transaction.TransactionID)
// }
// }
// func TestDuplicateTransactionNotStoredInDatabase(t *testing.T) {
// }

View file

@ -51,6 +51,10 @@ type Database interface {
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, receipt *shared.Receipt) error
// these don't have contexts passed in as we want things to happen regardless of the request context // these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error

View file

@ -27,6 +27,11 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type transactionEntry struct {
transaction gomatrixserverlib.Transaction
userID []gomatrixserverlib.UserID
}
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool IsLocalServerName func(gomatrixserverlib.ServerName) bool
@ -42,6 +47,7 @@ type Database struct {
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys ServerSigningKeys tables.FederationServerSigningKeys
transactionDB map[Receipt]transactionEntry
} }
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
@ -257,3 +263,42 @@ func (d *Database) GetNotaryKeys(
}) })
return sks, err return sks, err
} }
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
receipt *Receipt,
) error {
if transaction, ok := d.transactionDB[*receipt]; ok {
for k := range destinations {
transaction.userID = append(transaction.userID, k)
}
d.transactionDB[*receipt] = transaction
} else {
return fmt.Errorf("No transactions exist with that NID")
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (gomatrixserverlib.Transaction, error) {
return gomatrixserverlib.Transaction{}, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count := int64(0)
for _, transaction := range d.transactionDB {
for _, user := range transaction.userID {
if user == userID {
count++
}
}
}
return count, nil
}

2
go.mod
View file

@ -142,3 +142,5 @@ require (
) )
go 1.18 go 1.18
replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/mailbox