Handle retrieving async events on request

This commit is contained in:
Devon Hudson 2022-12-02 14:07:40 -07:00
parent bfa784b224
commit 56e9b54f31
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
11 changed files with 207 additions and 34 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
)
@ -700,6 +701,30 @@ func (r *FederationInternalAPI) QueryAsyncTransactions(
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
}
return nil
}

View file

@ -1250,7 +1250,7 @@ func TestSendPDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination)
assert.Equal(t, assumedOffline, true)
assert.Equal(t, true, assumedOffline)
}
func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
@ -1289,5 +1289,5 @@ func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination)
assert.Equal(t, assumedOffline, true)
assert.Equal(t, true, assumedOffline)
}

View file

@ -16,6 +16,7 @@ type AsyncEventsResponse struct {
// GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID}
func GetAsyncEvents(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
@ -30,7 +31,7 @@ func GetAsyncEvents(
return util.JSONResponse{
Code: http.StatusOK,
JSON: AsyncEventsResponse{
Transaction: gomatrixserverlib.Transaction{},
Transaction: response.Txn,
Remaining: response.RemainingCount,
},
}

View file

@ -37,10 +37,115 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.GetAsyncEvents(httpReq, fedAPI, *userID)
assert.Equal(t, response.Code, http.StatusOK)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, jsonResponse.Remaining, uint32(0))
assert.Equal(t, jsonResponse.Transaction, gomatrixserverlib.Transaction{})
assert.Equal(t, uint32(0), jsonResponse.Remaining)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
}
func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := createTransaction()
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
if err != nil {
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction.TransactionID,
receipt)
if err != nil {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(0), jsonResponse.Remaining)
assert.Equal(t, transaction, jsonResponse.Transaction)
}
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := createTransaction()
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
if err != nil {
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction.TransactionID,
receipt)
if err != nil {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
transaction2 := createTransaction()
receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2)
if err != nil {
t.Fatalf("Failed to store transaction: %s", err.Error())
}
err = db.AssociateAsyncTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction2.TransactionID,
receipt2)
if err != nil {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(1), jsonResponse.Remaining)
assert.Equal(t, transaction, jsonResponse.Transaction)
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(0), jsonResponse.Remaining)
assert.Equal(t, transaction2, jsonResponse.Transaction)
}

View file

@ -65,8 +65,10 @@ func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx,
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
@ -174,7 +176,7 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) {
response := routing.ForwardAsync(
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
transaction, err := db.GetAsyncTransaction(context.TODO(), *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}

View file

@ -150,6 +150,20 @@ func Setup(
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {

View file

@ -52,9 +52,10 @@ type Database interface {
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
// these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error

View file

@ -349,27 +349,52 @@ func (d *Database) AssociateAsyncTransactionWithDestinations(
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.nid
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, error) {
) (*gomatrixserverlib.Transaction, *Receipt, error) {
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txn, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids)
txns, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids)
if err != nil {
return nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txn[nids[0]], transaction)
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, fmt.Errorf("Unmarshall transaction: %w", err)
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
return transaction, nil
receipt := NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(

View file

@ -62,7 +62,7 @@ func TestShouldInsertAssumedOfflineServer(t *testing.T) {
t.Fatalf("Failed retrieving server: %s", err.Error())
}
assert.Equal(t, isOffline, true)
assert.Equal(t, true, isOffline)
})
}
@ -85,7 +85,7 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline, true)
assert.Equal(t, true, isOffline)
err = db.Table.DeleteAssumedOffline(ctx, nil, server1)
if err != nil {
@ -96,13 +96,13 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline, false)
assert.Equal(t, false, isOffline)
isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline2, true)
assert.Equal(t, true, isOffline2)
})
}
@ -125,13 +125,13 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline, true)
assert.Equal(t, true, isOffline)
isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline2, true)
assert.Equal(t, true, isOffline2)
err = db.Table.DeleteAllAssumedOffline(ctx, nil)
if err != nil {
@ -142,11 +142,11 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline, false)
assert.Equal(t, false, isOffline)
isOffline2, err = db.Table.SelectAssumedOffline(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving server status: %s", err.Error())
}
assert.Equal(t, isOffline2, false)
assert.Equal(t, false, isOffline2)
})
}

View file

@ -85,8 +85,8 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, retrievedNids[0], nid)
assert.Equal(t, len(retrievedNids), 1)
assert.Equal(t, nid, retrievedNids[0])
assert.Equal(t, 1, len(retrievedNids))
})
}
@ -117,7 +117,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(0))
assert.Equal(t, int64(0), count)
})
}
@ -160,12 +160,12 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(1))
assert.Equal(t, int64(1), count)
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(1))
assert.Equal(t, int64(1), count)
})
}

View file

@ -112,7 +112,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, len(storedJSON), 1)
assert.Equal(t, 1, len(storedJSON))
var storedTx gomatrixserverlib.Transaction
json.Unmarshal(storedJSON[1], &storedTx)
@ -156,6 +156,6 @@ func TestShouldDeleteTransaction(t *testing.T) {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, len(storedJSON), 0)
assert.Equal(t, 0, len(storedJSON))
})
}