mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Handle retrieving async events on request
This commit is contained in:
parent
bfa784b224
commit
56e9b54f31
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
)
|
||||
|
|
@ -700,6 +701,30 @@ func (r *FederationInternalAPI) QueryAsyncTransactions(
|
|||
request *api.QueryAsyncTransactionsRequest,
|
||||
response *api.QueryAsyncTransactionsResponse,
|
||||
) error {
|
||||
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
|
||||
// TODO : Should delete transaction json from table if no more associations
|
||||
if transaction != nil && receipt != nil {
|
||||
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO : These db calls should happen at the same time right?
|
||||
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.RemainingCount = uint32(count)
|
||||
if transaction != nil {
|
||||
response.Txn = *transaction
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1250,7 +1250,7 @@ func TestSendPDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
|
|||
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
|
||||
|
||||
assumedOffline, _ := db.IsServerAssumedOffline(destination)
|
||||
assert.Equal(t, assumedOffline, true)
|
||||
assert.Equal(t, true, assumedOffline)
|
||||
}
|
||||
|
||||
func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
|
||||
|
|
@ -1289,5 +1289,5 @@ func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
|
|||
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
|
||||
|
||||
assumedOffline, _ := db.IsServerAssumedOffline(destination)
|
||||
assert.Equal(t, assumedOffline, true)
|
||||
assert.Equal(t, true, assumedOffline)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ type AsyncEventsResponse struct {
|
|||
// GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID}
|
||||
func GetAsyncEvents(
|
||||
httpReq *http.Request,
|
||||
fedReq *gomatrixserverlib.FederationRequest,
|
||||
fedAPI api.FederationInternalAPI,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -30,7 +31,7 @@ func GetAsyncEvents(
|
|||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: AsyncEventsResponse{
|
||||
Transaction: gomatrixserverlib.Transaction{},
|
||||
Transaction: response.Txn,
|
||||
Remaining: response.RemainingCount,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,10 +37,115 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
|
|||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, fedAPI, *userID)
|
||||
assert.Equal(t, response.Code, http.StatusOK)
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
assert.Equal(t, jsonResponse.Remaining, uint32(0))
|
||||
assert.Equal(t, jsonResponse.Transaction, gomatrixserverlib.Transaction{})
|
||||
assert.Equal(t, uint32(0), jsonResponse.Remaining)
|
||||
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
|
||||
}
|
||||
|
||||
func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
transaction := createTransaction()
|
||||
|
||||
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to store transaction: %s", err.Error())
|
||||
}
|
||||
err = db.AssociateAsyncTransactionWithDestinations(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.UserID]struct{}{
|
||||
*userID: {},
|
||||
},
|
||||
transaction.TransactionID,
|
||||
receipt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
|
||||
}
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
assert.Equal(t, uint32(0), jsonResponse.Remaining)
|
||||
assert.Equal(t, transaction, jsonResponse.Transaction)
|
||||
}
|
||||
|
||||
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
|
||||
transaction := createTransaction()
|
||||
receipt, err := db.StoreAsyncTransaction(context.Background(), transaction)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to store transaction: %s", err.Error())
|
||||
}
|
||||
err = db.AssociateAsyncTransactionWithDestinations(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.UserID]struct{}{
|
||||
*userID: {},
|
||||
},
|
||||
transaction.TransactionID,
|
||||
receipt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
|
||||
}
|
||||
|
||||
transaction2 := createTransaction()
|
||||
receipt2, err := db.StoreAsyncTransaction(context.Background(), transaction2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to store transaction: %s", err.Error())
|
||||
}
|
||||
err = db.AssociateAsyncTransactionWithDestinations(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.UserID]struct{}{
|
||||
*userID: {},
|
||||
},
|
||||
transaction2.TransactionID,
|
||||
receipt2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
|
||||
}
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
assert.Equal(t, uint32(1), jsonResponse.Remaining)
|
||||
assert.Equal(t, transaction, jsonResponse.Transaction)
|
||||
|
||||
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse = response.JSON.(routing.AsyncEventsResponse)
|
||||
assert.Equal(t, uint32(0), jsonResponse.Remaining)
|
||||
assert.Equal(t, transaction2, jsonResponse.Transaction)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,8 +65,10 @@ func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx,
|
|||
if limit > len(d.associations[serverName]) {
|
||||
resultCount = len(d.associations[serverName])
|
||||
}
|
||||
for i := 0; i < resultCount; i++ {
|
||||
results = append(results, d.associations[serverName][i])
|
||||
if resultCount > 0 {
|
||||
for i := 0; i < resultCount; i++ {
|
||||
results = append(results, d.associations[serverName][i])
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
|
|
@ -174,7 +176,7 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) {
|
|||
|
||||
response := routing.ForwardAsync(
|
||||
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
|
||||
transaction, err := db.GetAsyncTransaction(context.TODO(), *userID)
|
||||
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,6 +150,20 @@ func Setup(
|
|||
},
|
||||
)).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
|
||||
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username was invalid"),
|
||||
}
|
||||
}
|
||||
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
|
||||
},
|
||||
)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
|
||||
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -52,9 +52,10 @@ type Database interface {
|
|||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
|
||||
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, error)
|
||||
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
|
||||
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
|
||||
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
|
||||
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
|
||||
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
|
||||
|
||||
// these don't have contexts passed in as we want things to happen regardless of the request context
|
||||
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
|
||||
|
|
|
|||
|
|
@ -349,27 +349,52 @@ func (d *Database) AssociateAsyncTransactionWithDestinations(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) CleanAsyncTransactions(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
receipts []*Receipt,
|
||||
) error {
|
||||
println(len(receipts))
|
||||
nids := make([]int64, len(receipts))
|
||||
for i, receipt := range receipts {
|
||||
nids[i] = receipt.nid
|
||||
}
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.insertQueueTransaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransaction(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) (*gomatrixserverlib.Transaction, error) {
|
||||
) (*gomatrixserverlib.Transaction, *Receipt, error) {
|
||||
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
|
||||
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
|
||||
}
|
||||
if len(nids) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
txn, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids)
|
||||
txns, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
|
||||
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
|
||||
}
|
||||
|
||||
transaction := &gomatrixserverlib.Transaction{}
|
||||
err = json.Unmarshal(txn[nids[0]], transaction)
|
||||
err = json.Unmarshal(txns[nids[0]], transaction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unmarshall transaction: %w", err)
|
||||
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
|
||||
}
|
||||
|
||||
return transaction, nil
|
||||
receipt := NewReceipt(nids[0])
|
||||
return transaction, &receipt, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransactionCount(
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func TestShouldInsertAssumedOfflineServer(t *testing.T) {
|
|||
t.Fatalf("Failed retrieving server: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, isOffline, true)
|
||||
assert.Equal(t, true, isOffline)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline, true)
|
||||
assert.Equal(t, true, isOffline)
|
||||
|
||||
err = db.Table.DeleteAssumedOffline(ctx, nil, server1)
|
||||
if err != nil {
|
||||
|
|
@ -96,13 +96,13 @@ func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline, false)
|
||||
assert.Equal(t, false, isOffline)
|
||||
|
||||
isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline2, true)
|
||||
assert.Equal(t, true, isOffline2)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -125,13 +125,13 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline, true)
|
||||
assert.Equal(t, true, isOffline)
|
||||
isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, isOffline2, true)
|
||||
assert.Equal(t, true, isOffline2)
|
||||
|
||||
err = db.Table.DeleteAllAssumedOffline(ctx, nil)
|
||||
if err != nil {
|
||||
|
|
@ -142,11 +142,11 @@ func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline, false)
|
||||
assert.Equal(t, false, isOffline)
|
||||
isOffline2, err = db.Table.SelectAssumedOffline(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving server status: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, isOffline2, false)
|
||||
assert.Equal(t, false, isOffline2)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
|
|||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, retrievedNids[0], nid)
|
||||
assert.Equal(t, len(retrievedNids), 1)
|
||||
assert.Equal(t, nid, retrievedNids[0])
|
||||
assert.Equal(t, 1, len(retrievedNids))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -117,7 +117,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, count, int64(0))
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -160,12 +160,12 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, count, int64(1))
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, count, int64(1))
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
|
|||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, len(storedJSON), 1)
|
||||
assert.Equal(t, 1, len(storedJSON))
|
||||
|
||||
var storedTx gomatrixserverlib.Transaction
|
||||
json.Unmarshal(storedJSON[1], &storedTx)
|
||||
|
|
@ -156,6 +156,6 @@ func TestShouldDeleteTransaction(t *testing.T) {
|
|||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, len(storedJSON), 0)
|
||||
assert.Equal(t, 0, len(storedJSON))
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue