diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 362333fc9..0d83e36b3 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -30,6 +30,12 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error + + PerformStoreAsync( + ctx context.Context, + request *PerformStoreAsyncRequest, + response *PerformStoreAsyncResponse, + ) error } type ClientFederationAPI interface { @@ -213,6 +219,14 @@ type PerformBroadcastEDURequest struct { type PerformBroadcastEDUResponse struct { } +type PerformStoreAsyncRequest struct { + Txn gomatrixserverlib.Transaction `json:"transaction"` + UserID gomatrixserverlib.UserID `json:"user_id"` +} + +type PerformStoreAsyncResponse struct { +} + type InputPublicKeysRequest struct { Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 1b61ec711..970f2fc39 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -648,6 +648,27 @@ func (r *FederationInternalAPI) PerformBroadcastEDU( return nil } +// PerformStoreAsync implements api.FederationInternalAPI +func (r *FederationInternalAPI) PerformStoreAsync( + ctx context.Context, + request *api.PerformStoreAsyncRequest, + response *api.PerformStoreAsyncResponse, +) error { + receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn) + if err != nil { + return err + } + err = r.db.AssociateAsyncTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + request.UserID: {}, + }, + request.Txn.TransactionID, + receipt) + + return err +} + func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { _ = r.db.RemoveServerFromBlacklist(srv) diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 812d3c6da..9fdb30cef 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -23,6 +23,7 @@ const ( FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" + FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -150,6 +151,17 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( ) } +func (h *httpFederationInternalAPI) PerformStoreAsync( + ctx context.Context, + request *api.PerformStoreAsyncRequest, + response *api.PerformStoreAsyncResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath, + h.httpClient, ctx, request, response, + ) +} + type getUserDevices struct { S gomatrixserverlib.ServerName UserID string diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7aa0e4801..cc55fab64 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + internalAPIMux.Handle( + FederationAPIPerformStoreAsyncPath, + httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync), + ) + internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( diff --git a/federationapi/routing/forwardasync.go b/federationapi/routing/forwardasync.go index e9a46d35e..9b9bb307a 100644 --- a/federationapi/routing/forwardasync.go +++ b/federationapi/routing/forwardasync.go @@ -1,8 +1,10 @@ package routing import ( + "encoding/json" "net/http" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -13,12 +15,56 @@ func ForwardAsync( httpReq *http.Request, fedReq *gomatrixserverlib.FederationRequest, fedAPI api.FederationInternalAPI, - txnId gomatrixserverlib.TransactionID, + txnID gomatrixserverlib.TransactionID, userID gomatrixserverlib.UserID, ) util.JSONResponse { + var txnEvents struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` + } - // TODO: wrap in fedAPI call - // fedAPI.db.AssociateAsyncTransactionWithDestinations(context.TODO(), userID, nil) + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + println("The request body could not be decoded into valid JSON. " + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + req := api.PerformStoreAsyncRequest{ + Txn: t, + UserID: userID, + } + res := api.PerformStoreAsyncResponse{} + err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + // Naming: + // mailServer? assign mailserver for user? + // configure my mailserver + // Homeserver, idendity server, mailserver... why not? return util.JSONResponse{Code: 200} -} \ No newline at end of file +} diff --git a/federationapi/routing/forwardasync_test.go b/federationapi/routing/forwardasync_test.go index 4c06a86d4..4b330b37f 100644 --- a/federationapi/routing/forwardasync_test.go +++ b/federationapi/routing/forwardasync_test.go @@ -1,27 +1,146 @@ package routing_test import ( - // "context" + "context" + "database/sql" + "encoding/json" + "fmt" "net/http" + "sync" "testing" + "time" "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/routing" - // "github.com/matrix-org/dendrite/federationapi/storage/shared" - // "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +type testDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func createDatabase() *testDatabase { + return &testDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + + return results, nil +} + +func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} + +func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + txn.Destination = testDestination + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", txn.Destination, path) + request.SetContent(txn) + + return txn, request +} + func TestEmptyForwardReturnsOk(t *testing.T) { + testDB := createDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + FederationQueueTransactions: testDB, + FederationTransactionJSON: testDB, + } httpReq := &http.Request{} - request := &gomatrixserverlib.FederationRequest{} - fedAPI := internal.FederationInternalAPI{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) if err != nil { t.Fatalf("Invalid userID: %s", err.Error()) } + _, request := createFederationRequest(*userID) - response := routing.ForwardAsync(httpReq, request, &fedAPI, "1", *userID) + fedAPI := internal.NewFederationInternalAPI( + &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + ) + + response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID) expected := 200 if response.Code != expected { @@ -29,36 +148,45 @@ func TestEmptyForwardReturnsOk(t *testing.T) { } } -// func TestUniqueTransactionStoredInDatabase(t *testing.T) { -// db := shared.Database{} -// httpReq := &http.Request{} -// inputTransaction := gomatrixserverlib.Transaction{} -// request := &gomatrixserverlib.FederationRequest{} -// fedAPI := internal.NewFederationInternalAPI( -// &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, -// ) -// userID, err := gomatrixserverlib.NewUserID("@local:domain", false) -// if err != nil { -// t.Fatalf("Invalid userID: %s", err.Error()) -// } +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := createDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + FederationQueueTransactions: testDB, + FederationTransactionJSON: testDB, + } -// response := routing.ForwardAsync(httpReq, request, fedAPI, "1", *userID) -// transaction, err := db.GetAsyncTransaction(context.TODO(), *userID) -// transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + inputTransaction, request := createFederationRequest(*userID) -// expected := 200 -// if response.Code != expected { -// t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) -// } -// if transactionCount != 1 { -// t.Fatalf("Expected count of 1, Actual: %d", transactionCount) -// } -// if transaction.TransactionID != inputTransaction.TransactionID { -// t.Fatalf("Expected Transaction ID: %s, Actual: %s", -// inputTransaction.TransactionID, transaction.TransactionID) -// } -// } + fedAPI := internal.NewFederationInternalAPI( + &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + ) -// func TestDuplicateTransactionNotStoredInDatabase(t *testing.T) { + response := routing.ForwardAsync( + httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID) + transaction, err := db.GetAsyncTransaction(context.TODO(), *userID) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } -// } \ No newline at end of file + expected := 200 + if response.Code != expected { + t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) + } + if transactionCount != 1 { + t.Fatalf("Expected count of 1, Actual: %d", transactionCount) + } + if transaction.TransactionID != inputTransaction.TransactionID { + t.Fatalf("Expected Transaction ID: %s, Actual: %s", + inputTransaction.TransactionID, transaction.TransactionID) + } +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f16e5093..f6ef64444 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -133,6 +133,23 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) + v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI( + "federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return ForwardAsync( + httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 4540a65fe..2597f0830 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -51,9 +51,10 @@ type Database interface { GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, error) + StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error) + GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, error) GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) - AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, receipt *shared.Receipt) error + AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error // these don't have contexts passed in as we want things to happen regardless of the request context AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index a33fa4a43..8e38603ed 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -54,10 +54,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + queueTransactions, err := NewPostgresQueueTransactionsTable(d.db) + if err != nil { + return nil, err + } queueJSON, err := NewPostgresQueueJSONTable(d.db) if err != nil { return nil, err } + transactionJSON, err := NewPostgresTransactionJSONTable(d.db) + if err != nil { + return nil, err + } blacklist, err := NewPostgresBlacklistTable(d.db) if err != nil { return nil, err @@ -95,20 +103,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } d.Database = shared.Database{ - DB: d.db, - IsLocalServerName: isLocalServerName, - Cache: cache, - Writer: d.writer, - FederationJoinedHosts: joinedHosts, - FederationQueuePDUs: queuePDUs, - FederationQueueEDUs: queueEDUs, - FederationQueueJSON: queueJSON, - FederationBlacklist: blacklist, - FederationInboundPeeks: inboundPeeks, - FederationOutboundPeeks: outboundPeeks, - NotaryServerKeysJSON: notaryJSON, - NotaryServerKeysMetadata: notaryMetadata, - ServerSigningKeys: serverSigningKeys, + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + FederationJoinedHosts: joinedHosts, + FederationQueuePDUs: queuePDUs, + FederationQueueEDUs: queueEDUs, + FederationQueueJSON: queueJSON, + FederationQueueTransactions: queueTransactions, + FederationTransactionJSON: transactionJSON, + FederationBlacklist: blacklist, + FederationInboundPeeks: inboundPeeks, + FederationOutboundPeeks: outboundPeeks, + NotaryServerKeysJSON: notaryJSON, + NotaryServerKeysMetadata: notaryMetadata, + ServerSigningKeys: serverSigningKeys, } return &d, nil } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index ab5d66ac8..dae25b197 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -17,6 +17,7 @@ package shared import ( "context" "database/sql" + "encoding/json" "fmt" "time" @@ -27,27 +28,23 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -type transactionEntry struct { - transaction gomatrixserverlib.Transaction - userID []gomatrixserverlib.UserID -} - type Database struct { - DB *sql.DB - IsLocalServerName func(gomatrixserverlib.ServerName) bool - Cache caching.FederationCache - Writer sqlutil.Writer - FederationQueuePDUs tables.FederationQueuePDUs - FederationQueueEDUs tables.FederationQueueEDUs - FederationQueueJSON tables.FederationQueueJSON - FederationJoinedHosts tables.FederationJoinedHosts - FederationBlacklist tables.FederationBlacklist - FederationOutboundPeeks tables.FederationOutboundPeeks - FederationInboundPeeks tables.FederationInboundPeeks - NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON - NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata - ServerSigningKeys tables.FederationServerSigningKeys - transactionDB map[Receipt]transactionEntry + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + FederationQueuePDUs tables.FederationQueuePDUs + FederationQueueEDUs tables.FederationQueueEDUs + FederationQueueJSON tables.FederationQueueJSON + FederationJoinedHosts tables.FederationJoinedHosts + FederationBlacklist tables.FederationBlacklist + FederationOutboundPeeks tables.FederationOutboundPeeks + FederationInboundPeeks tables.FederationInboundPeeks + NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON + NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata + ServerSigningKeys tables.FederationServerSigningKeys + FederationQueueTransactions tables.FederationQueueTransactions + FederationTransactionJSON tables.FederationTransactionJSON } // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. @@ -264,18 +261,43 @@ func (d *Database) GetNotaryKeys( return sks, err } +func (d *Database) StoreAsyncTransaction( + ctx context.Context, txn gomatrixserverlib.Transaction, +) (*Receipt, error) { + var err error + json, err := json.Marshal(txn) + if err != nil { + return nil, fmt.Errorf("d.JSONUnmarshall: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertTransactionJSON: %w", err) + } + return &Receipt{ + nid: nid, + }, nil +} + func (d *Database) AssociateAsyncTransactionWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, receipt *Receipt, ) error { - if transaction, ok := d.transactionDB[*receipt]; ok { - for k := range destinations { - transaction.userID = append(transaction.userID, k) + for destination := range destinations { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err := d.FederationQueueTransactions.InsertQueueTransaction( + ctx, txn, transactionID, destination.Domain(), receipt.nid) + return err + }) + if err != nil { + return fmt.Errorf("d.insertQueueTransaction: %w", err) } - d.transactionDB[*receipt] = transaction - } else { - return fmt.Errorf("No transactions exist with that NID") } return nil @@ -284,21 +306,33 @@ func (d *Database) AssociateAsyncTransactionWithDestinations( func (d *Database) GetAsyncTransaction( ctx context.Context, userID gomatrixserverlib.UserID, -) (gomatrixserverlib.Transaction, error) { - return gomatrixserverlib.Transaction{}, nil +) (*gomatrixserverlib.Transaction, error) { + nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1) + if err != nil { + return nil, fmt.Errorf("d.SelectQueueTransaction: %w", err) + } + + txn, err := d.FederationTransactionJSON.SelectTransactionJSON(ctx, nil, nids) + if err != nil { + return nil, fmt.Errorf("d.SelectTransactionJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + err = json.Unmarshal(txn[nids[0]], transaction) + if err != nil { + return nil, fmt.Errorf("Unmarshall transaction: %w", err) + } + + return transaction, nil } func (d *Database) GetAsyncTransactionCount( ctx context.Context, userID gomatrixserverlib.UserID, ) (int64, error) { - count := int64(0) - for _, transaction := range d.transactionDB { - for _, user := range transaction.userID { - if user == userID { - count++ - } - } + count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err) } return count, nil } diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e86ac817b..e8fa9a0b6 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,11 +1,5 @@ // Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -57,6 +51,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db) + if err != nil { + return nil, err + } + transactionJSON, err := NewSQLiteTransactionJSONTable(d.db) + if err != nil { + return nil, err + } blacklist, err := NewSQLiteBlacklistTable(d.db) if err != nil { return nil, err @@ -94,20 +96,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } d.Database = shared.Database{ - DB: d.db, - IsLocalServerName: isLocalServerName, - Cache: cache, - Writer: d.writer, - FederationJoinedHosts: joinedHosts, - FederationQueuePDUs: queuePDUs, - FederationQueueEDUs: queueEDUs, - FederationQueueJSON: queueJSON, - FederationBlacklist: blacklist, - FederationOutboundPeeks: outboundPeeks, - FederationInboundPeeks: inboundPeeks, - NotaryServerKeysJSON: notaryKeys, - NotaryServerKeysMetadata: notaryKeysMetadata, - ServerSigningKeys: serverSigningKeys, + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + FederationJoinedHosts: joinedHosts, + FederationQueuePDUs: queuePDUs, + FederationQueueEDUs: queueEDUs, + FederationQueueJSON: queueJSON, + FederationQueueTransactions: queueTransactions, + FederationTransactionJSON: transactionJSON, + FederationBlacklist: blacklist, + FederationOutboundPeeks: outboundPeeks, + FederationInboundPeeks: inboundPeeks, + NotaryServerKeysJSON: notaryKeys, + NotaryServerKeysMetadata: notaryKeysMetadata, + ServerSigningKeys: serverSigningKeys, } return &d, nil } diff --git a/federationapi/storage/tables/transaction_json_table_test.go b/federationapi/storage/tables/transaction_json_table_test.go index 6ebeff508..7fbf9bb62 100644 --- a/federationapi/storage/tables/transaction_json_table_test.go +++ b/federationapi/storage/tables/transaction_json_table_test.go @@ -23,7 +23,7 @@ const ( testDestination = gomatrixserverlib.ServerName("white.orchard") ) -func mustCreateTransaction(userID gomatrixserverlib.UserID) gomatrixserverlib.Transaction { +func mustCreateTransaction() gomatrixserverlib.Transaction { txn := gomatrixserverlib.Transaction{} txn.PDUs = []json.RawMessage{ []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), @@ -73,11 +73,7 @@ func TestShoudInsertTransaction(t *testing.T) { db, close := mustCreateTransactionJSONTable(t, dbType) defer close() - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - transaction := mustCreateTransaction(*userID) + transaction := mustCreateTransaction() tx, err := json.Marshal(transaction) if err != nil { t.Fatalf("Invalid transaction: %s", err.Error()) @@ -96,11 +92,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) { db, close := mustCreateTransactionJSONTable(t, dbType) defer close() - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - transaction := mustCreateTransaction(*userID) + transaction := mustCreateTransaction() tx, err := json.Marshal(transaction) if err != nil { t.Fatalf("Invalid transaction: %s", err.Error()) @@ -135,11 +127,7 @@ func TestShouldDeleteTransaction(t *testing.T) { db, close := mustCreateTransactionJSONTable(t, dbType) defer close() - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - transaction := mustCreateTransaction(*userID) + transaction := mustCreateTransaction() tx, err := json.Marshal(transaction) if err != nil { t.Fatalf("Invalid transaction: %s", err.Error()) diff --git a/go.mod b/go.mod index 0b40b8d1f..24633d5e8 100644 --- a/go.mod +++ b/go.mod @@ -143,4 +143,4 @@ require ( go 1.18 -replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/mailbox +replace github.com/matrix-org/gomatrixserverlib => github.com/matrix-org/gomatrixserverlib v0.0.0-20221110204444-22af9cae40c5 diff --git a/go.sum b/go.sum index 5d172e86b..3760356d2 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221110204444-22af9cae40c5 h1:06o3BPKc0CeYK6rOn/tzP9SZKTDAE2zF4AmWmZei1CU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221110204444-22af9cae40c5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=