diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 10f0a9efc..1eb9329f0 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -92,7 +92,7 @@ type FederationClient interface { SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) SendAsyncTransaction(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) - GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error) + GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error) // Perform operations LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) diff --git a/relayapi/api/api.go b/relayapi/api/api.go index fab76becd..ef15b9401 100644 --- a/relayapi/api/api.go +++ b/relayapi/api/api.go @@ -49,7 +49,7 @@ type RelayServerAPI interface { type PerformRelayServerSyncRequest struct { UserID gomatrixserverlib.UserID `json:"user_id"` - RelayServer gomatrixserverlib.ServerName `json:"relay_name"` + RelayServer gomatrixserverlib.ServerName `json:"relay_server"` } type PerformRelayServerSyncResponse struct { @@ -72,10 +72,12 @@ type PerformStoreAsyncResponse struct { } type QueryAsyncTransactionsRequest struct { - UserID gomatrixserverlib.UserID `json:"user_id"` + UserID gomatrixserverlib.UserID `json:"user_id"` + PreviousEntry gomatrixserverlib.RelayEntry `json:"prev_entry,omitempty"` } type QueryAsyncTransactionsResponse struct { - Txn gomatrixserverlib.Transaction `json:"transaction"` - RemainingCount uint32 `json:"remaining"` + Txn gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:entries_queued` } diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go index 955941a10..aff914b3f 100644 --- a/relayapi/internal/perform.go +++ b/relayapi/internal/perform.go @@ -30,20 +30,25 @@ func (r *RelayInternalAPI) PerformRelayServerSync( request *api.PerformRelayServerSyncRequest, response *api.PerformRelayServerSyncResponse, ) error { - asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer) + prevEntry := gomatrixserverlib.RelayEntry{EntryID: -1} + asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, prevEntry, request.RelayServer) if err != nil { logrus.Errorf("GetAsyncEvents: %s", err.Error()) return err } - r.processTransaction(&asyncResponse.Transaction) + r.processTransaction(&asyncResponse.Txn) - for asyncResponse.Remaining > 0 { - asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer) + for asyncResponse.EntriesQueued { + logrus.Info("Retrieving next entry from relay") + logrus.Infof("Previous entry: %v", prevEntry) + asyncResponse, err = r.fedClient.GetAsyncEvents(ctx, request.UserID, prevEntry, request.RelayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + logrus.Infof("New previous entry: %v", prevEntry) if err != nil { logrus.Errorf("GetAsyncEvents: %s", err.Error()) return err } - r.processTransaction(&asyncResponse.Transaction) + r.processTransaction(&asyncResponse.Txn) } return nil @@ -58,6 +63,7 @@ func (r *RelayInternalAPI) PerformStoreAsync( logrus.Warnf("Storing transaction for %v", request.UserID) receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn) if err != nil { + logrus.Errorf("db.StoreAsyncTransaction: %s", err.Error()) return err } err = r.db.AssociateAsyncTransactionWithDestinations( @@ -77,36 +83,37 @@ func (r *RelayInternalAPI) QueryAsyncTransactions( request *api.QueryAsyncTransactionsRequest, response *api.QueryAsyncTransactionsResponse, ) error { - logrus.Warnf("Obtaining transaction for %v", request.UserID) - transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID) - if err != nil { - return err - } - - // TODO : Shouldn't be deleting unless the transaction was successfully returned... - // TODO : Should delete transaction json from table if no more associations - // Maybe track last received transaction, and send that as part of the request, - // then delete before getting the new events from the db. - if transaction != nil && receipt != nil { - err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt}) + logrus.Infof("QueryAsyncTransactions for %s", request.UserID.Raw()) + if request.PreviousEntry.EntryID >= 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + request.PreviousEntry.EntryID, + request.UserID.Raw(), + ) + prevReceipt := shared.NewReceipt(request.PreviousEntry.EntryID) + err := r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{&prevReceipt}) if err != nil { + logrus.Errorf("db.CleanAsyncTransactions: %s", err.Error()) return err } - - // TODO : Clean async transactions json } - // TODO : These db calls should happen at the same time right? - count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID) + transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID) if err != nil { + logrus.Errorf("db.GetAsyncTransaction: %s", err.Error()) return err } - response.RemainingCount = uint32(count) - if transaction != nil { + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, request.UserID.Raw()) response.Txn = *transaction - logrus.Warnf("Obtained transaction: %v", transaction.TransactionID) + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", request.UserID.Raw()) + response.EntryID = -1 + response.EntriesQueued = false } + return nil } diff --git a/relayapi/routing/asyncevents.go b/relayapi/routing/asyncevents.go index a86ae05ef..1b97c5c0a 100644 --- a/relayapi/routing/asyncevents.go +++ b/relayapi/routing/asyncevents.go @@ -1,6 +1,7 @@ package routing import ( + "encoding/json" "net/http" "github.com/matrix-org/dendrite/relayapi/api" @@ -10,8 +11,9 @@ import ( ) type AsyncEventsResponse struct { - Transaction gomatrixserverlib.Transaction `json:"transaction"` - Remaining uint32 `json:"remaining"` + Txn gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id,omitempty"` + EntriesQueued bool `json:"entries_queued"` } // GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID} @@ -22,9 +24,27 @@ func GetAsyncEvents( relayAPI api.RelayInternalAPI, userID gomatrixserverlib.UserID, ) util.JSONResponse { - logrus.Infof("Handling async_events for %v", userID) + logrus.Infof("Handling async_events for %s", userID.Raw()) + + entryProvided := false + var previousEntry gomatrixserverlib.RelayEntry + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err == nil { + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + entryProvided = true + } + + request := api.QueryAsyncTransactionsRequest{ + UserID: userID, + PreviousEntry: gomatrixserverlib.RelayEntry{EntryID: -1}, + } + if entryProvided { + request.PreviousEntry = previousEntry + } var response api.QueryAsyncTransactionsResponse - err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response) + err := relayAPI.QueryAsyncTransactions( + httpReq.Context(), + &request, + &response) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -34,8 +54,9 @@ func GetAsyncEvents( return util.JSONResponse{ Code: http.StatusOK, JSON: AsyncEventsResponse{ - Transaction: response.Txn, - Remaining: response.RemainingCount, + Txn: response.Txn, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, }, } } diff --git a/relayapi/routing/asyncevents_test.go b/relayapi/routing/asyncevents_test.go index 6f2ff10a7..7d65a8085 100644 --- a/relayapi/routing/asyncevents_test.go +++ b/relayapi/routing/asyncevents_test.go @@ -14,6 +14,19 @@ import ( "github.com/stretchr/testify/assert" ) +func createAsyncQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/async_events/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), relayServer, path) + request.SetContent(prevEntry) + + return request +} + func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { testDB := storage.NewFakeRelayDatabase() db := shared.Database{ @@ -37,12 +50,16 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) + request := createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}, "relay") + response := routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, uint32(0), jsonResponse.Remaining) - assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) + + count, err := db.GetAsyncTransactionCount(context.Background(), *userID) + assert.Equal(t, count, int64(0)) } func TestGetAsyncReturnsSavedTransaction(t *testing.T) { @@ -58,7 +75,6 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { t.Fatalf("Invalid userID: %s", err.Error()) } transaction := createTransaction() - receipt, err := db.StoreAsyncTransaction(context.Background(), transaction) if err != nil { t.Fatalf("Failed to store transaction: %s", err.Error()) @@ -78,12 +94,25 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) + request := createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}, "relay") + response := routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, uint32(0), jsonResponse.Remaining) - assert.Equal(t, transaction, jsonResponse.Transaction) + assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Txn) + + // And once more to clear the queue + request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay") + response = routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.AsyncEventsResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) + + count, err := db.GetAsyncTransactionCount(context.Background(), *userID) + assert.Equal(t, count, int64(0)) } func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { @@ -135,17 +164,31 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) + request := createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}, "relay") + response := routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, uint32(1), jsonResponse.Remaining) - assert.Equal(t, transaction, jsonResponse.Transaction) + assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Txn) - response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) + request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay") + response = routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse = response.JSON.(routing.AsyncEventsResponse) - assert.Equal(t, uint32(0), jsonResponse.Remaining) - assert.Equal(t, transaction2, jsonResponse.Transaction) + assert.Equal(t, true, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Txn) + + // And once more to clear the queue + request = createAsyncQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}, "relay") + response = routing.GetAsyncEvents(httpReq, &request, &relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.AsyncEventsResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Txn) + + count, err := db.GetAsyncTransactionCount(context.Background(), *userID) + assert.Equal(t, count, int64(0)) } diff --git a/relayapi/routing/forwardasync.go b/relayapi/routing/forwardasync.go index 9f078da7e..ac9390519 100644 --- a/relayapi/routing/forwardasync.go +++ b/relayapi/routing/forwardasync.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // ForwardAsync implements /_matrix/federation/v1/forward_async/{txnID}/{userID} @@ -25,12 +26,13 @@ func ForwardAsync( } if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { - println("The request body could not be decoded into valid JSON. " + err.Error()) + logrus.Info("The request body could not be decoded into valid JSON. " + err.Error()) return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. // https://matrix.org/docs/spec/server_server/latest#transactions if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 {