From 56e9b54f31830cb97ea3e250f35360ccd7bee1ef Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Dec 2022 14:07:40 -0700 Subject: [PATCH] Handle retrieving async events on request --- federationapi/internal/perform.go | 25 ++++ federationapi/queue/queue_test.go | 4 +- federationapi/routing/asyncevents.go | 3 +- federationapi/routing/asyncevents_test.go | 113 +++++++++++++++++- federationapi/routing/forwardasync_test.go | 8 +- federationapi/routing/routing.go | 14 +++ federationapi/storage/interface.go | 5 +- federationapi/storage/shared/storage.go | 39 ++++-- .../tables/assumed_offline_table_test.go | 16 +-- .../tables/queue_transactions_table_test.go | 10 +- .../tables/transaction_json_table_test.go | 4 +- 11 files changed, 207 insertions(+), 34 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index a46c41054..fdc97d6c1 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/storage/shared" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -700,6 +701,30 @@ func (r *FederationInternalAPI) QueryAsyncTransactions( request *api.QueryAsyncTransactionsRequest, response *api.QueryAsyncTransactionsResponse, ) error { + transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID) + if err != nil { + return err + } + + // TODO : Shouldn't be deleting unless the transaction was successfully returned... + // TODO : Should delete transaction json from table if no more associations + if transaction != nil && receipt != nil { + err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt}) + if err != nil { + return err + } + } + + // TODO : These db calls should happen at the same time right? + count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID) + if err != nil { + return err + } + + response.RemainingCount = uint32(count) + if transaction != nil { + response.Txn = *transaction + } return nil } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index d447882af..9506ede7d 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -1250,7 +1250,7 @@ func TestSendPDUOnAsyncSuccessRemovedFromDB(t *testing.T) { poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) assumedOffline, _ := db.IsServerAssumedOffline(destination) - assert.Equal(t, assumedOffline, true) + assert.Equal(t, true, assumedOffline) } func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) { @@ -1289,5 +1289,5 @@ func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) { poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) assumedOffline, _ := db.IsServerAssumedOffline(destination) - assert.Equal(t, assumedOffline, true) + assert.Equal(t, true, assumedOffline) } diff --git a/federationapi/routing/asyncevents.go b/federationapi/routing/asyncevents.go index 0348160ff..9a35e6805 100644 --- a/federationapi/routing/asyncevents.go +++ b/federationapi/routing/asyncevents.go @@ -16,6 +16,7 @@ type AsyncEventsResponse struct { // GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID} func GetAsyncEvents( httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, fedAPI api.FederationInternalAPI, userID gomatrixserverlib.UserID, ) util.JSONResponse { @@ -30,7 +31,7 @@ func GetAsyncEvents( return util.JSONResponse{ Code: http.StatusOK, JSON: AsyncEventsResponse{ - Transaction: gomatrixserverlib.Transaction{}, + Transaction: response.Txn, Remaining: response.RemainingCount, }, } diff --git a/federationapi/routing/asyncevents_test.go b/federationapi/routing/asyncevents_test.go index b9258e210..f9775a90e 100644 --- a/federationapi/routing/asyncevents_test.go +++ b/federationapi/routing/asyncevents_test.go @@ -37,10 +37,115 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, ) - response := routing.GetAsyncEvents(httpReq, fedAPI, *userID) - assert.Equal(t, response.Code, http.StatusOK) + response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, jsonResponse.Remaining, uint32(0)) - assert.Equal(t, jsonResponse.Transaction, gomatrixserverlib.Transaction{}) + assert.Equal(t, uint32(0), jsonResponse.Remaining) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) +} + +func TestGetAsyncReturnsSavedTransaction(t *testing.T) { + testDB := createDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + FederationQueueTransactions: testDB, + FederationTransactionJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + transaction := createTransaction() + + receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) + if err != nil { + t.Fatalf("Failed to store transaction: %s", err.Error()) + } + err = db.AssociateAsyncTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + if err != nil { + t.Fatalf("Failed to associate transaction with user: %s", err.Error()) + } + + fedAPI := internal.NewFederationInternalAPI( + &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + ) + + response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.AsyncEventsResponse) + assert.Equal(t, uint32(0), jsonResponse.Remaining) + assert.Equal(t, transaction, jsonResponse.Transaction) +} + +func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { + testDB := createDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + FederationQueueTransactions: testDB, + FederationTransactionJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + + transaction := createTransaction() + receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) + if err != nil { + t.Fatalf("Failed to store transaction: %s", err.Error()) + } + err = db.AssociateAsyncTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + if err != nil { + t.Fatalf("Failed to associate transaction with user: %s", err.Error()) + } + + transaction2 := createTransaction() + receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2) + if err != nil { + t.Fatalf("Failed to store transaction: %s", err.Error()) + } + err = db.AssociateAsyncTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + if err != nil { + t.Fatalf("Failed to associate transaction with user: %s", err.Error()) + } + + fedAPI := internal.NewFederationInternalAPI( + &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + ) + + response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.AsyncEventsResponse) + assert.Equal(t, uint32(1), jsonResponse.Remaining) + assert.Equal(t, transaction, jsonResponse.Transaction) + + response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.AsyncEventsResponse) + assert.Equal(t, uint32(0), jsonResponse.Remaining) + assert.Equal(t, transaction2, jsonResponse.Transaction) } diff --git a/federationapi/routing/forwardasync_test.go b/federationapi/routing/forwardasync_test.go index fe3df24ab..d62c04b61 100644 --- a/federationapi/routing/forwardasync_test.go +++ b/federationapi/routing/forwardasync_test.go @@ -65,8 +65,10 @@ func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, if limit > len(d.associations[serverName]) { resultCount = len(d.associations[serverName]) } - for i := 0; i < resultCount; i++ { - results = append(results, d.associations[serverName][i]) + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } } return results, nil @@ -174,7 +176,7 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) { response := routing.ForwardAsync( httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID) - transaction, err := db.GetAsyncTransaction(context.TODO(), *userID) + transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID) if err != nil { t.Fatalf("Failed retrieving transaction: %s", err.Error()) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index f6ef64444..ebdf32ab6 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -150,6 +150,20 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) + v1fedmux.Handle("/async_events/{userID}", MakeFedAPI( + "federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetAsyncEvents(httpReq, request, fsAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) + v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 6108862f7..a1118fdfa 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -52,9 +52,10 @@ type Database interface { GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error) - GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, error) - GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error + CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error + GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error) + GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) // these don't have contexts passed in as we want things to happen regardless of the request context AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1ec9f5d87..1f98eb5ae 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -349,27 +349,52 @@ func (d *Database) AssociateAsyncTransactionWithDestinations( return nil } +func (d *Database) CleanAsyncTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*Receipt, +) error { + println(len(receipts)) + nids := make([]int64, len(receipts)) + for i, receipt := range receipts { + nids[i] = receipt.nid + } + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids) + return err + }) + if err != nil { + return fmt.Errorf("d.insertQueueTransaction: %w", err) + } + + return nil +} + func (d *Database) GetAsyncTransaction( ctx context.Context, userID gomatrixserverlib.UserID, -) (*gomatrixserverlib.Transaction, error) { +) (*gomatrixserverlib.Transaction, *Receipt, error) { nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1) if err != nil { - return nil, fmt.Errorf("d.SelectQueueTransaction: %w", err) + return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil } - txn, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids) + txns, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids) if err != nil { - return nil, fmt.Errorf("d.SelectTransactionJSON: %w", err) + return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err) } transaction := &gomatrixserverlib.Transaction{} - err = json.Unmarshal(txn[nids[0]], transaction) + err = json.Unmarshal(txns[nids[0]], transaction) if err != nil { - return nil, fmt.Errorf("Unmarshall transaction: %w", err) + return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err) } - return transaction, nil + receipt := NewReceipt(nids[0]) + return transaction, &receipt, nil } func (d *Database) GetAsyncTransactionCount( diff --git a/federationapi/storage/tables/assumed_offline_table_test.go b/federationapi/storage/tables/assumed_offline_table_test.go index 9c855ec79..e1f08ed56 100644 --- a/federationapi/storage/tables/assumed_offline_table_test.go +++ b/federationapi/storage/tables/assumed_offline_table_test.go @@ -62,7 +62,7 @@ func TestShouldInsertAssumedOfflineServer(t *testing.T) { t.Fatalf("Failed retrieving server: %s", err.Error()) } - assert.Equal(t, isOffline, true) + assert.Equal(t, true, isOffline) }) } @@ -85,7 +85,7 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline, true) + assert.Equal(t, true, isOffline) err = db.Table.DeleteAssumedOffline(ctx, nil, server1) if err != nil { @@ -96,13 +96,13 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline, false) + assert.Equal(t, false, isOffline) isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2) if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline2, true) + assert.Equal(t, true, isOffline2) }) } @@ -125,13 +125,13 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline, true) + assert.Equal(t, true, isOffline) isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2) if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline2, true) + assert.Equal(t, true, isOffline2) err = db.Table.DeleteAllAssumedOffline(ctx, nil) if err != nil { @@ -142,11 +142,11 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline, false) + assert.Equal(t, false, isOffline) isOffline2, err = db.Table.SelectAssumedOffline(ctx, nil, server2) if err != nil { t.Fatalf("Failed retrieving server status: %s", err.Error()) } - assert.Equal(t, isOffline2, false) + assert.Equal(t, false, isOffline2) }) } diff --git a/federationapi/storage/tables/queue_transactions_table_test.go b/federationapi/storage/tables/queue_transactions_table_test.go index 46d8a3bf3..9266f6c95 100644 --- a/federationapi/storage/tables/queue_transactions_table_test.go +++ b/federationapi/storage/tables/queue_transactions_table_test.go @@ -85,8 +85,8 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { t.Fatalf("Failed retrieving transaction: %s", err.Error()) } - assert.Equal(t, retrievedNids[0], nid) - assert.Equal(t, len(retrievedNids), 1) + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) }) } @@ -117,7 +117,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } - assert.Equal(t, count, int64(0)) + assert.Equal(t, int64(0), count) }) } @@ -160,12 +160,12 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } - assert.Equal(t, count, int64(1)) + assert.Equal(t, int64(1), count) count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2) if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } - assert.Equal(t, count, int64(1)) + assert.Equal(t, int64(1), count) }) } diff --git a/federationapi/storage/tables/transaction_json_table_test.go b/federationapi/storage/tables/transaction_json_table_test.go index 7fbf9bb62..9569b0f0c 100644 --- a/federationapi/storage/tables/transaction_json_table_test.go +++ b/federationapi/storage/tables/transaction_json_table_test.go @@ -112,7 +112,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) { t.Fatalf("Failed retrieving transaction: %s", err.Error()) } - assert.Equal(t, len(storedJSON), 1) + assert.Equal(t, 1, len(storedJSON)) var storedTx gomatrixserverlib.Transaction json.Unmarshal(storedJSON[1], &storedTx) @@ -156,6 +156,6 @@ func TestShouldDeleteTransaction(t *testing.T) { t.Fatalf("Failed retrieving transaction: %s", err.Error()) } - assert.Equal(t, len(storedJSON), 0) + assert.Equal(t, 0, len(storedJSON)) }) }