diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 8e9af36b0..3bac6e90a 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -458,7 +458,7 @@ func (m *DendriteMonolith) Start() { switch e := event.(type) { case pineconeEvents.PeerAdded: if !relayServerSyncRunning.Load() { - go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) + // go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) } case pineconeEvents.PeerRemoved: if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 { @@ -486,46 +486,6 @@ func (m *DendriteMonolith) Start() { }(pineconeEventChannel) } -func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) { - defer running.Store(false) - - t := time.NewTimer(relayServerRetryInterval) - for { - relayServersToQuery := []gomatrixserverlib.ServerName{} - for server, complete := range m.relayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) - } - } - if len(relayServersToQuery) == 0 { - // All relay servers have been synced. - return - } - m.queryRelayServers(relayServersToQuery) - t.Reset(relayServerRetryInterval) - - select { - case <-stop: - if !t.Stop() { - <-t.C - } - return - case <-t.C: - } - } -} - -func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - for _, server := range relayServers { - request := api.PerformRelayServerSyncRequest{RelayServer: server} - response := api.PerformRelayServerSyncResponse{} - err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response) - if err == nil { - m.relayServersQueried[server] = true - } - } -} - func (m *DendriteMonolith) Stop() { m.processContext.ShutdownDendrite() _ = m.listener.Close() diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index d594cb863..a6cee1696 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -43,6 +43,7 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -145,6 +146,7 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false @@ -235,6 +237,20 @@ func main() { userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + monolith := setup.Monolith{ Config: base.Cfg, Client: conn.CreateClient(base, pQUIC), @@ -246,6 +262,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + RelayAPI: &relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -319,32 +336,12 @@ func main() { relayServerSyncRunning := atomic.NewBool(false) stopRelayServerSync := make(chan bool) - js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.FederationAPI.Matrix.JetStream) - producer := &producers.SyncAPIProducer{ - JetStream: js, - TopicReceiptEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), - TopicSendToDeviceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - TopicTypingEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent), - TopicPresenceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), - TopicDeviceListUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - TopicSigningKeyUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - Config: &base.Cfg.FederationAPI, - UserAPI: userAPI, - } - m := RelayServerRetriever{ Context: context.Background(), ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), FederationAPI: fsAPI, RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: relayapi.NewRelayAPI( - federation, - rsAPI, - keyRing, - producer, - cfg.Global.Presence.EnableInbound, - cfg.Global.ServerName, - ), + RelayAPI: monolith.RelayAPI, } m.InitializeRelayServers(eLog) @@ -387,7 +384,7 @@ type RelayServerRetriever struct { ServerName gomatrixserverlib.ServerName FederationAPI api.FederationInternalAPI RelayServersQueried map[gomatrixserverlib.ServerName]bool - RelayAPI relayapi.RelayAPI + RelayAPI relayServerAPI.RelayInternalAPI } func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { @@ -440,7 +437,12 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli if err != nil { return } - err = m.RelayAPI.PerformRelayServerSync(*userID, server) + request := relayServerAPI.PerformRelayServerSyncRequest{ + UserID: *userID, + RelayServer: server, + } + response := relayServerAPI.PerformRelayServerSyncResponse{} + err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response) if err == nil { m.RelayServersQueried[server] = true // TODO : What happens if your relay receives new messages after this point? diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 3b234d5eb..10f0a9efc 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,7 +18,6 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI - RelayServerAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -45,22 +44,6 @@ type FederationInternalAPI interface { ) error } -type RelayServerAPI interface { - // Store async transactions for forwarding to the destination at a later time. - PerformStoreAsync( - ctx context.Context, - request *PerformStoreAsyncRequest, - response *PerformStoreAsyncResponse, - ) error - - // Obtain the oldest stored transaction for the specified userID. - QueryAsyncTransactions( - ctx context.Context, - request *QueryAsyncTransactionsRequest, - response *QueryAsyncTransactionsResponse, - ) error -} - type ClientFederationAPI interface { // Query the server names of the joined hosts in a room. // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index edb943526..20591a3cd 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,7 +14,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" - "github.com/matrix-org/dendrite/federationapi/storage/shared" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -847,64 +846,3 @@ func (r *FederationInternalAPI) QueryRelayServers( response.RelayServers = relayServers return nil } - -// PerformStoreAsync implements api.FederationInternalAPI -func (r *FederationInternalAPI) PerformStoreAsync( - ctx context.Context, - request *api.PerformStoreAsyncRequest, - response *api.PerformStoreAsyncResponse, -) error { - logrus.Warnf("Storing transaction for %v", request.UserID) - receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn) - if err != nil { - return err - } - err = r.db.AssociateAsyncTransactionWithDestinations( - ctx, - map[gomatrixserverlib.UserID]struct{}{ - request.UserID: {}, - }, - request.Txn.TransactionID, - receipt) - - return err -} - -// QueryAsyncTransactions implements api.FederationInternalAPI -func (r *FederationInternalAPI) QueryAsyncTransactions( - ctx context.Context, - request *api.QueryAsyncTransactionsRequest, - response *api.QueryAsyncTransactionsResponse, -) error { - logrus.Warnf("Obtaining transaction for %v", request.UserID) - transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID) - if err != nil { - return err - } - - // TODO : Shouldn't be deleting unless the transaction was successfully returned... - // TODO : Should delete transaction json from table if no more associations - // Maybe track last received transaction, and send that as part of the request, - // then delete before getting the new events from the db. - if transaction != nil && receipt != nil { - err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt}) - if err != nil { - return err - } - - // TODO : Clean async transactions json - } - - // TODO : These db calls should happen at the same time right? - count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID) - if err != nil { - return err - } - - response.RemainingCount = uint32(count) - if transaction != nil { - response.Txn = *transaction - logrus.Warnf("Obtained transaction: %v", transaction.TransactionID) - } - return nil -} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index cc63cd96a..de384335d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -26,9 +26,6 @@ const ( FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" - FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync" - FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions" - FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" @@ -525,25 +522,3 @@ func (h *httpFederationInternalAPI) QueryRelayServers( h.httpClient, ctx, request, response, ) } - -func (h *httpFederationInternalAPI) PerformStoreAsync( - ctx context.Context, - request *api.PerformStoreAsyncRequest, - response *api.PerformStoreAsyncResponse, -) error { - return httputil.CallInternalRPCAPI( - "PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath, - h.httpClient, ctx, request, response, - ) -} - -func (h *httpFederationInternalAPI) QueryAsyncTransactions( - ctx context.Context, - request *api.QueryAsyncTransactionsRequest, - response *api.QueryAsyncTransactionsResponse, -) error { - return httputil.CallInternalRPCAPI( - "QueryAsyncTransactions", h.federationAPIURL+FederationAPIQueryAsyncTransactionsPath, - h.httpClient, ctx, request, response, - ) -} diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 70480db85..21a070392 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -43,16 +43,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) - internalAPIMux.Handle( - FederationAPIPerformStoreAsyncPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync), - ) - - internalAPIMux.Handle( - FederationAPIQueryAsyncTransactionsPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions), - ) - internalAPIMux.Handle( FederationAPIPerformWakeupServers, httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), diff --git a/federationapi/routing/forwardasync_test.go b/federationapi/routing/forwardasync_test.go deleted file mode 100644 index 09ec82871..000000000 --- a/federationapi/routing/forwardasync_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package routing_test - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "net/http" - "sync" - "testing" - "time" - - "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/federationapi/routing" - "github.com/matrix-org/dendrite/federationapi/storage/shared" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") -) - -type testDatabase struct { - nid int64 - nidMutex sync.Mutex - transactions map[int64]json.RawMessage - associations map[gomatrixserverlib.ServerName][]int64 -} - -func createDatabase() *testDatabase { - return &testDatabase{ - nid: 1, - nidMutex: sync.Mutex{}, - transactions: make(map[int64]json.RawMessage), - associations: make(map[gomatrixserverlib.ServerName][]int64), - } -} - -func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error { - if _, ok := d.associations[serverName]; !ok { - d.associations[serverName] = []int64{} - } - d.associations[serverName] = append(d.associations[serverName], nid) - return nil -} - -func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error { - for _, nid := range jsonNIDs { - for index, associatedNID := range d.associations[serverName] { - if associatedNID == nid { - d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) - } - } - } - - return nil -} - -func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) { - results := []int64{} - resultCount := limit - if limit > len(d.associations[serverName]) { - resultCount = len(d.associations[serverName]) - } - if resultCount > 0 { - for i := 0; i < resultCount; i++ { - results = append(results, d.associations[serverName][i]) - } - } - - return results, nil -} - -func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) { - return int64(len(d.associations[serverName])), nil -} - -func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) { - d.nidMutex.Lock() - defer d.nidMutex.Unlock() - - nid := d.nid - d.transactions[nid] = []byte(json) - d.nid++ - - return nid, nil -} - -func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error { - for _, nid := range nids { - delete(d.transactions, nid) - } - - return nil -} - -func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) { - result := make(map[int64][]byte) - for _, nid := range jsonNIDs { - if transaction, ok := d.transactions[nid]; ok { - result[nid] = transaction - } - } - - return result, nil -} - -func createTransaction() gomatrixserverlib.Transaction { - txn := gomatrixserverlib.Transaction{} - txn.PDUs = []json.RawMessage{ - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - } - txn.Origin = testOrigin - txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - txn.Destination = testDestination - return txn -} - -func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) { - txn := createTransaction() - var federationPathPrefixV1 = "/_matrix/federation/v1" - path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() - request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path) - request.SetContent(txn) - - return txn, request -} - -func TestEmptyForwardReturnsOk(t *testing.T) { - testDB := createDatabase() - db := shared.Database{ - Writer: sqlutil.NewDummyWriter(), - FederationQueueTransactions: testDB, - FederationTransactionJSON: testDB, - } - httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - _, request := createFederationRequest(*userID) - - fedAPI := internal.NewFederationInternalAPI( - &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, - ) - - response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID) - - expected := 200 - if response.Code != expected { - t.Fatalf("Expected: %v, Actual: %v", expected, response.Code) - } -} - -func TestUniqueTransactionStoredInDatabase(t *testing.T) { - testDB := createDatabase() - db := shared.Database{ - Writer: sqlutil.NewDummyWriter(), - FederationQueueTransactions: testDB, - FederationTransactionJSON: testDB, - } - - httpReq := &http.Request{} - userID, err := gomatrixserverlib.NewUserID("@local:domain", false) - if err != nil { - t.Fatalf("Invalid userID: %s", err.Error()) - } - inputTransaction, request := createFederationRequest(*userID) - - fedAPI := internal.NewFederationInternalAPI( - &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, - ) - - response := routing.ForwardAsync( - httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID) - transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID) - if err != nil { - t.Fatalf("Failed retrieving transaction: %s", err.Error()) - } - transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) - if err != nil { - t.Fatalf("Failed retrieving transaction count: %s", err.Error()) - } - - expected := 200 - if response.Code != expected { - t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) - } - if transactionCount != 1 { - t.Fatalf("Expected count of 1, Actual: %d", transactionCount) - } - if transaction.TransactionID != inputTransaction.TransactionID { - t.Fatalf("Expected Transaction ID: %s, Actual: %s", - inputTransaction.TransactionID, transaction.TransactionID) - } -} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index beabf894a..3f30cd4ee 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -133,37 +133,6 @@ func Setup( }, )).Methods(http.MethodPut, http.MethodOptions) - v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI( - "federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username was invalid"), - } - } - return ForwardAsync( - httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]), - *userID, - ) - }, - )).Methods(http.MethodPut, http.MethodOptions) - - v1fedmux.Handle("/async_events/{userID}", MakeFedAPI( - "federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username was invalid"), - } - } - return GetAsyncEvents(httpReq, request, fsAPI, *userID) - }, - )).Methods(http.MethodGet, http.MethodOptions) - v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 8105ed2a3..c70d5e9ed 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -51,12 +51,6 @@ type Database interface { GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error) - AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error - CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error - GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error) - GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) - // these don't have contexts passed in as we want things to happen regardless of the request context AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 7bb1a2037..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -58,18 +58,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - queueTransactions, err := NewPostgresQueueTransactionsTable(d.db) - if err != nil { - return nil, err - } queueJSON, err := NewPostgresQueueJSONTable(d.db) if err != nil { return nil, err } - transactionJSON, err := NewPostgresTransactionJSONTable(d.db) - if err != nil { - return nil, err - } assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) if err != nil { return nil, err @@ -111,24 +103,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } d.Database = shared.Database{ - DB: d.db, - IsLocalServerName: isLocalServerName, - Cache: cache, - Writer: d.writer, - FederationJoinedHosts: joinedHosts, - FederationQueuePDUs: queuePDUs, - FederationQueueEDUs: queueEDUs, - FederationQueueJSON: queueJSON, - FederationQueueTransactions: queueTransactions, - FederationTransactionJSON: transactionJSON, - FederationBlacklist: blacklist, - FederationAssumedOffline: assumedOffline, - FederationRelayServers: relayServers, - FederationInboundPeeks: inboundPeeks, - FederationOutboundPeeks: outboundPeeks, - NotaryServerKeysJSON: notaryJSON, - NotaryServerKeysMetadata: notaryMetadata, - ServerSigningKeys: serverSigningKeys, + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + FederationJoinedHosts: joinedHosts, + FederationQueuePDUs: queuePDUs, + FederationQueueEDUs: queueEDUs, + FederationQueueJSON: queueJSON, + FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, + FederationInboundPeeks: inboundPeeks, + FederationOutboundPeeks: outboundPeeks, + NotaryServerKeysJSON: notaryJSON, + NotaryServerKeysMetadata: notaryMetadata, + ServerSigningKeys: serverSigningKeys, } return &d, nil } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 06fe2901e..3668d6e2c 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -17,7 +17,6 @@ package shared import ( "context" "database/sql" - "encoding/json" "fmt" "time" @@ -29,24 +28,22 @@ import ( ) type Database struct { - DB *sql.DB - IsLocalServerName func(gomatrixserverlib.ServerName) bool - Cache caching.FederationCache - Writer sqlutil.Writer - FederationQueuePDUs tables.FederationQueuePDUs - FederationQueueEDUs tables.FederationQueueEDUs - FederationQueueJSON tables.FederationQueueJSON - FederationJoinedHosts tables.FederationJoinedHosts - FederationBlacklist tables.FederationBlacklist - FederationAssumedOffline tables.FederationAssumedOffline - FederationRelayServers tables.FederationRelayServers - FederationOutboundPeeks tables.FederationOutboundPeeks - FederationInboundPeeks tables.FederationInboundPeeks - NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON - NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata - ServerSigningKeys tables.FederationServerSigningKeys - FederationQueueTransactions tables.FederationQueueTransactions - FederationTransactionJSON tables.FederationTransactionJSON + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + FederationQueuePDUs tables.FederationQueuePDUs + FederationQueueEDUs tables.FederationQueueEDUs + FederationQueueJSON tables.FederationQueueJSON + FederationJoinedHosts tables.FederationJoinedHosts + FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers + FederationOutboundPeeks tables.FederationOutboundPeeks + FederationInboundPeeks tables.FederationInboundPeeks + NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON + NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata + ServerSigningKeys tables.FederationServerSigningKeys } // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. @@ -61,6 +58,10 @@ func NewReceipt(nid int64) Receipt { return Receipt{nid: nid} } +func (r *Receipt) GetNID() int64 { + return r.nid +} + func (r *Receipt) String() string { return fmt.Sprintf("%d", r.nid) } @@ -308,108 +309,3 @@ func (d *Database) GetNotaryKeys( }) return sks, err } - -func (d *Database) StoreAsyncTransaction( - ctx context.Context, txn gomatrixserverlib.Transaction, -) (*Receipt, error) { - var err error - json, err := json.Marshal(txn) - if err != nil { - return nil, fmt.Errorf("d.JSONUnmarshall: %w", err) - } - - var nid int64 - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json)) - return err - }) - if err != nil { - return nil, fmt.Errorf("d.insertTransactionJSON: %w", err) - } - return &Receipt{ - nid: nid, - }, nil -} - -func (d *Database) AssociateAsyncTransactionWithDestinations( - ctx context.Context, - destinations map[gomatrixserverlib.UserID]struct{}, - transactionID gomatrixserverlib.TransactionID, - receipt *Receipt, -) error { - for destination := range destinations { - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - err := d.FederationQueueTransactions.InsertQueueTransaction( - ctx, txn, transactionID, destination.Domain(), receipt.nid) - return err - }) - if err != nil { - return fmt.Errorf("d.insertQueueTransaction: %w", err) - } - } - - return nil -} - -func (d *Database) CleanAsyncTransactions( - ctx context.Context, - userID gomatrixserverlib.UserID, - receipts []*Receipt, -) error { - println(len(receipts)) - nids := make([]int64, len(receipts)) - for i, receipt := range receipts { - nids[i] = receipt.nid - } - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids) - return err - }) - if err != nil { - return fmt.Errorf("d.insertQueueTransaction: %w", err) - } - - return nil -} - -func (d *Database) GetAsyncTransaction( - ctx context.Context, - userID gomatrixserverlib.UserID, -) (*gomatrixserverlib.Transaction, *Receipt, error) { - nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1) - if err != nil { - return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err) - } - if len(nids) == 0 { - return nil, nil, nil - } - - txns := map[int64][]byte{} - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - txns, err = d.FederationTransactionJSON.SelectTransactionJSON(ctx, txn, nids) - return err - }) - if err != nil { - return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err) - } - - transaction := &gomatrixserverlib.Transaction{} - err = json.Unmarshal(txns[nids[0]], transaction) - if err != nil { - return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err) - } - - receipt := NewReceipt(nids[0]) - return transaction, &receipt, nil -} - -func (d *Database) GetAsyncTransactionCount( - ctx context.Context, - userID gomatrixserverlib.UserID, -) (int64, error) { - count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain()) - if err != nil { - return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err) - } - return count, nil -} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index dfc788a81..028e2f880 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -55,14 +55,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db) - if err != nil { - return nil, err - } - transactionJSON, err := NewSQLiteTransactionJSONTable(d.db) - if err != nil { - return nil, err - } assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) if err != nil { return nil, err @@ -104,24 +96,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } d.Database = shared.Database{ - DB: d.db, - IsLocalServerName: isLocalServerName, - Cache: cache, - Writer: d.writer, - FederationJoinedHosts: joinedHosts, - FederationQueuePDUs: queuePDUs, - FederationQueueEDUs: queueEDUs, - FederationQueueJSON: queueJSON, - FederationQueueTransactions: queueTransactions, - FederationTransactionJSON: transactionJSON, - FederationBlacklist: blacklist, - FederationAssumedOffline: assumedOffline, - FederationRelayServers: relayServers, - FederationOutboundPeeks: outboundPeeks, - FederationInboundPeeks: inboundPeeks, - NotaryServerKeysJSON: notaryKeys, - NotaryServerKeysMetadata: notaryKeysMetadata, - ServerSigningKeys: serverSigningKeys, + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + FederationJoinedHosts: joinedHosts, + FederationQueuePDUs: queuePDUs, + FederationQueueEDUs: queueEDUs, + FederationQueueJSON: queueJSON, + FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, + FederationOutboundPeeks: outboundPeeks, + FederationInboundPeeks: inboundPeeks, + NotaryServerKeysJSON: notaryKeys, + NotaryServerKeysMetadata: notaryKeysMetadata, + ServerSigningKeys: serverSigningKeys, } return &d, nil } diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 000000000..fab76becd --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,81 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + PerformRelayServerSync( + ctx context.Context, + request *PerformRelayServerSyncRequest, + response *PerformRelayServerSyncResponse, + ) error +} + +type RelayServerAPI interface { + // Store async transactions for forwarding to the destination at a later time. + PerformStoreAsync( + ctx context.Context, + request *PerformStoreAsyncRequest, + response *PerformStoreAsyncResponse, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryAsyncTransactions( + ctx context.Context, + request *QueryAsyncTransactionsRequest, + response *QueryAsyncTransactionsResponse, + ) error +} + +type PerformRelayServerSyncRequest struct { + UserID gomatrixserverlib.UserID `json:"user_id"` + RelayServer gomatrixserverlib.ServerName `json:"relay_name"` +} + +type PerformRelayServerSyncResponse struct { +} + +type QueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type QueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} + +type PerformStoreAsyncRequest struct { + Txn gomatrixserverlib.Transaction `json:"transaction"` + UserID gomatrixserverlib.UserID `json:"user_id"` +} + +type PerformStoreAsyncResponse struct { +} + +type QueryAsyncTransactionsRequest struct { + UserID gomatrixserverlib.UserID `json:"user_id"` +} + +type QueryAsyncTransactionsResponse struct { + Txn gomatrixserverlib.Transaction `json:"transaction"` + RemainingCount uint32 `json:"remaining"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 000000000..3b0995edc --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient *gomatrixserverlib.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, +) RelayInternalAPI { + return RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 000000000..955941a10 --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,131 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformRelayServerSync implements api.FederationInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + request *api.PerformRelayServerSyncRequest, + response *api.PerformRelayServerSyncResponse, +) error { + asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer) + if err != nil { + logrus.Errorf("GetAsyncEvents: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Transaction) + + for asyncResponse.Remaining > 0 { + asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer) + if err != nil { + logrus.Errorf("GetAsyncEvents: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Transaction) + } + + return nil +} + +// PerformStoreAsync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreAsync( + ctx context.Context, + request *api.PerformStoreAsyncRequest, + response *api.PerformStoreAsyncResponse, +) error { + logrus.Warnf("Storing transaction for %v", request.UserID) + receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn) + if err != nil { + return err + } + err = r.db.AssociateAsyncTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + request.UserID: {}, + }, + request.Txn.TransactionID, + receipt) + + return err +} + +// QueryAsyncTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryAsyncTransactions( + ctx context.Context, + request *api.QueryAsyncTransactionsRequest, + response *api.QueryAsyncTransactionsResponse, +) error { + logrus.Warnf("Obtaining transaction for %v", request.UserID) + transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID) + if err != nil { + return err + } + + // TODO : Shouldn't be deleting unless the transaction was successfully returned... + // TODO : Should delete transaction json from table if no more associations + // Maybe track last received transaction, and send that as part of the request, + // then delete before getting the new events from the db. + if transaction != nil && receipt != nil { + err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt}) + if err != nil { + return err + } + + // TODO : Clean async transactions json + } + + // TODO : These db calls should happen at the same time right? + count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID) + if err != nil { + return err + } + + response.RemainingCount = uint32(count) + if transaction != nil { + response.Txn = *transaction + logrus.Warnf("Obtained transaction: %v", transaction.TransactionID) + } + return nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/inthttp/client.go b/relayapi/inthttp/client.go new file mode 100644 index 000000000..9aff88475 --- /dev/null +++ b/relayapi/inthttp/client.go @@ -0,0 +1,70 @@ +package inthttp + +import ( + "context" + "errors" + "net/http" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/relayapi/api" +) + +// HTTP paths for the internal HTTP API +const ( + RelayAPIPerformRelayServerSyncPath = "/relayapi/performRelayServerSync" + RelayAPIPerformStoreAsyncPath = "/relayapi/performStoreAsync" + RelayAPIQueryAsyncTransactionsPath = "/relayapi/queryAsyncTransactions" +) + +// NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API. +// If httpClient is nil an error is returned +func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) { + if httpClient == nil { + return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is ") + } + return &httpRelayInternalAPI{ + relayAPIURL: relayapiURL, + httpClient: httpClient, + cache: cache, + }, nil +} + +type httpRelayInternalAPI struct { + relayAPIURL string + httpClient *http.Client + cache caching.ServerKeyCache +} + +func (h *httpRelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + request *api.PerformRelayServerSyncRequest, + response *api.PerformRelayServerSyncResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformRelayServerSync", h.relayAPIURL+RelayAPIPerformRelayServerSyncPath, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpRelayInternalAPI) PerformStoreAsync( + ctx context.Context, + request *api.PerformStoreAsyncRequest, + response *api.PerformStoreAsyncResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformStoreAsync", h.relayAPIURL+RelayAPIPerformStoreAsyncPath, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpRelayInternalAPI) QueryAsyncTransactions( + ctx context.Context, + request *api.QueryAsyncTransactionsRequest, + response *api.QueryAsyncTransactionsResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAsyncTransactions", h.relayAPIURL+RelayAPIQueryAsyncTransactionsPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/relayapi/inthttp/server.go b/relayapi/inthttp/server.go new file mode 100644 index 000000000..9c15a73d3 --- /dev/null +++ b/relayapi/inthttp/server.go @@ -0,0 +1,27 @@ +package inthttp + +import ( + "github.com/gorilla/mux" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/relayapi/api" +) + +// AddRoutes adds the RelayInternalAPI handlers to the http.ServeMux. +// nolint:gocyclo +func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router) { + internalAPIMux.Handle( + RelayAPIPerformRelayServerSyncPath, + httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", intAPI.PerformRelayServerSync), + ) + + internalAPIMux.Handle( + RelayAPIPerformStoreAsyncPath, + httputil.MakeInternalRPCAPI("RelayAPIPerformStoreAsync", intAPI.PerformStoreAsync), + ) + + internalAPIMux.Handle( + RelayAPIQueryAsyncTransactionsPath, + httputil.MakeInternalRPCAPI("RelayAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions), + ) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go index 2237b0431..c69f78cd6 100644 --- a/relayapi/relayapi.go +++ b/relayapi/relayapi.go @@ -15,79 +15,77 @@ package relayapi import ( - "context" - + "github.com/gorilla/mux" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/internal" + keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/relayapi/api" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/inthttp" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) -type RelayAPI struct { - fedClient *gomatrixserverlib.FederationClient - rsAPI rsAPI.RoomserverInternalAPI - keyRing *gomatrixserverlib.KeyRing - producer *producers.SyncAPIProducer - presenceEnabledInbound bool - serverName gomatrixserverlib.ServerName +// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions +// on the given input API. +func AddInternalRoutes(router *mux.Router, intAPI api.RelayInternalAPI) { + inthttp.AddRoutes(intAPI, router) } -func NewRelayAPI( +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + userAPI userapi.UserInternalAPI, + fedClient *gomatrixserverlib.FederationClient, + keyRing gomatrixserverlib.JSONVerifier, + rsAPI rsAPI.FederationRoomserverAPI, + relayAPI relayAPI.RelayInternalAPI, + fedAPI federationAPI.FederationInternalAPI, + keyAPI keyserverAPI.FederationKeyAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, fedClient *gomatrixserverlib.FederationClient, rsAPI rsAPI.RoomserverInternalAPI, keyRing *gomatrixserverlib.KeyRing, producer *producers.SyncAPIProducer, - presenceEnabledInbound bool, - serverName gomatrixserverlib.ServerName, -) RelayAPI { - return RelayAPI{ - fedClient: fedClient, - rsAPI: rsAPI, - keyRing: keyRing, - producer: producer, - presenceEnabledInbound: presenceEnabledInbound, - serverName: serverName, - } -} +) internal.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI -// PerformRelayServerSync implements api.FederationInternalAPI -func (r *RelayAPI) PerformRelayServerSync(userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { - asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer) + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) if err != nil { - logrus.Errorf("GetAsyncEvents: %s", err.Error()) - return err - } - r.processTransaction(&asyncResponse.Transaction) - - for asyncResponse.Remaining > 0 { - asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer) - if err != nil { - logrus.Errorf("GetAsyncEvents: %s", err.Error()) - return err - } - r.processTransaction(&asyncResponse.Transaction) + logrus.WithError(err).Panic("failed to connect to relay db") } - return nil -} - -func (r *RelayAPI) processTransaction(txn *gomatrixserverlib.Transaction) { - logrus.Warn("Processing transaction from relay server") - mu := internal.NewMutexByRoom() - t := internal.NewTxnReq( - r.rsAPI, - nil, - r.serverName, - r.keyRing, - mu, - r.producer, - r.presenceEnabledInbound, - txn.PDUs, - txn.EDUs, - txn.Origin, - txn.TransactionID, - txn.Destination) - - t.ProcessTransaction(context.TODO()) + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + ) } diff --git a/federationapi/routing/asyncevents.go b/relayapi/routing/asyncevents.go similarity index 82% rename from federationapi/routing/asyncevents.go rename to relayapi/routing/asyncevents.go index 34ff0499c..a86ae05ef 100644 --- a/federationapi/routing/asyncevents.go +++ b/relayapi/routing/asyncevents.go @@ -3,7 +3,7 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -19,12 +19,12 @@ type AsyncEventsResponse struct { func GetAsyncEvents( httpReq *http.Request, fedReq *gomatrixserverlib.FederationRequest, - fedAPI api.FederationInternalAPI, + relayAPI api.RelayInternalAPI, userID gomatrixserverlib.UserID, ) util.JSONResponse { logrus.Infof("Handling async_events for %v", userID) var response api.QueryAsyncTransactionsResponse - err := fedAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response) + err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/federationapi/routing/asyncevents_test.go b/relayapi/routing/asyncevents_test.go similarity index 72% rename from federationapi/routing/asyncevents_test.go rename to relayapi/routing/asyncevents_test.go index f9775a90e..6f2ff10a7 100644 --- a/federationapi/routing/asyncevents_test.go +++ b/relayapi/routing/asyncevents_test.go @@ -5,21 +5,21 @@ import ( "net/http" "testing" - "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/federationapi/routing" - "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + "github.com/matrix-org/dendrite/relayapi/storage/shared" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { - testDB := createDatabase() + testDB := storage.NewFakeRelayDatabase() db := shared.Database{ - Writer: sqlutil.NewDummyWriter(), - FederationQueueTransactions: testDB, - FederationTransactionJSON: testDB, + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) @@ -33,11 +33,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { t.Fatalf("Failed to store transaction: %s", err.Error()) } - fedAPI := internal.NewFederationInternalAPI( - &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) @@ -46,11 +46,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { } func TestGetAsyncReturnsSavedTransaction(t *testing.T) { - testDB := createDatabase() + testDB := storage.NewFakeRelayDatabase() db := shared.Database{ - Writer: sqlutil.NewDummyWriter(), - FederationQueueTransactions: testDB, - FederationTransactionJSON: testDB, + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) @@ -74,11 +74,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { t.Fatalf("Failed to associate transaction with user: %s", err.Error()) } - fedAPI := internal.NewFederationInternalAPI( - &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) @@ -87,11 +87,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) { } func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { - testDB := createDatabase() + testDB := storage.NewFakeRelayDatabase() db := shared.Database{ - Writer: sqlutil.NewDummyWriter(), - FederationQueueTransactions: testDB, - FederationTransactionJSON: testDB, + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, } httpReq := &http.Request{} userID, err := gomatrixserverlib.NewUserID("@local:domain", false) @@ -131,18 +131,18 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { t.Fatalf("Failed to associate transaction with user: %s", err.Error()) } - fedAPI := internal.NewFederationInternalAPI( - &db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", ) - response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse := response.JSON.(routing.AsyncEventsResponse) assert.Equal(t, uint32(1), jsonResponse.Remaining) assert.Equal(t, transaction, jsonResponse.Transaction) - response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) + response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) jsonResponse = response.JSON.(routing.AsyncEventsResponse) diff --git a/federationapi/routing/forwardasync.go b/relayapi/routing/forwardasync.go similarity index 92% rename from federationapi/routing/forwardasync.go rename to relayapi/routing/forwardasync.go index a53d48b97..9f078da7e 100644 --- a/federationapi/routing/forwardasync.go +++ b/relayapi/routing/forwardasync.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -15,7 +15,7 @@ import ( func ForwardAsync( httpReq *http.Request, fedReq *gomatrixserverlib.FederationRequest, - fedAPI api.FederationInternalAPI, + relayAPI api.RelayInternalAPI, txnID gomatrixserverlib.TransactionID, userID gomatrixserverlib.UserID, ) util.JSONResponse { @@ -54,7 +54,7 @@ func ForwardAsync( UserID: userID, } res := api.PerformStoreAsyncResponse{} - err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res) + err := relayAPI.PerformStoreAsync(httpReq.Context(), &req, &res) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/relayapi/routing/forwardasync_test.go b/relayapi/routing/forwardasync_test.go new file mode 100644 index 000000000..f402e8a47 --- /dev/null +++ b/relayapi/routing/forwardasync_test.go @@ -0,0 +1,111 @@ +package routing_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + txn.Destination = testDestination + return txn +} + +func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) { + txn := createTransaction() + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path) + request.SetContent(txn) + + return txn, request +} + +func TestEmptyForwardReturnsOk(t *testing.T) { + testDB := storage.NewFakeRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + _, request := createFederationRequest(*userID) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID) + + expected := 200 + if response.Code != expected { + t.Fatalf("Expected: %v, Actual: %v", expected, response.Code) + } +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := storage.NewFakeRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + inputTransaction, request := createFederationRequest(*userID) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.ForwardAsync( + httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID) + transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + + expected := 200 + if response.Code != expected { + t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code) + } + if transactionCount != 1 { + t.Fatalf("Expected count of 1, Actual: %d", transactionCount) + } + if transaction.TransactionID != inputTransaction.TransactionID { + t.Fatalf("Expected Transaction ID: %s, Actual: %s", + inputTransaction.TransactionID, transaction.TransactionID) + } +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 000000000..4c284b688 --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,123 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeRelayAPI( + "relay_forward_async", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return ForwardAsync( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/async_events/{userID}", MakeRelayAPI( + "relay_async_events", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetAsyncEvents(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/storage/fake_relay_db.go b/relayapi/storage/fake_relay_db.go new file mode 100644 index 000000000..98a03b565 --- /dev/null +++ b/relayapi/storage/fake_relay_db.go @@ -0,0 +1,95 @@ +package storage + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type testDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewFakeRelayDatabase() *testDatabase { + return &testDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 000000000..a1218f804 --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,30 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error) + AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error + CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error + GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error) + GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/federationapi/storage/postgres/transaction_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go similarity index 63% rename from federationapi/storage/postgres/transaction_json_table.go rename to relayapi/storage/postgres/relay_queue_json_table.go index 507120edb..c3e1858f1 100644 --- a/federationapi/storage/postgres/transaction_json_table.go +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -23,60 +23,60 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" ) -const transactionJSONSchema = ` --- The federationsender_transaction_json table contains event contents that +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that -- we are storing for future forwarding. -CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( -- The JSON NID. This allows cross-referencing to find the JSON blob. json_nid BIGSERIAL, -- The JSON body. Text so that we preserve UTF-8. json_body TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx - ON federationsender_transaction_json (json_nid); +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); ` -const insertTransactionJSONSQL = "" + - "INSERT INTO federationsender_transaction_json (json_body)" + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + " VALUES ($1)" + " RETURNING json_nid" -const deleteTransactionJSONSQL = "" + - "DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)" +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" -const selectTransactionJSONSQL = "" + - "SELECT json_nid, json_body FROM federationsender_transaction_json" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + " WHERE json_nid = ANY($1)" -type transactionJSONStatements struct { +type relayQueueJSONStatements struct { db *sql.DB insertJSONStmt *sql.Stmt deleteJSONStmt *sql.Stmt selectJSONStmt *sql.Stmt } -func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { - s = &transactionJSONStatements{ +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ db: db, } - _, err = s.db.Exec(transactionJSONSchema) + _, err = s.db.Exec(relayQueueJSONSchema) if err != nil { return } - if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil { + if s.insertJSONStmt, err = s.db.Prepare(insertQueueJSONSQL); err != nil { return } - if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil { + if s.deleteJSONStmt, err = s.db.Prepare(deleteQueueJSONSQL); err != nil { return } - if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil { + if s.selectJSONStmt, err = s.db.Prepare(selectQueueJSONSQL); err != nil { return } return } -func (s *transactionJSONStatements) InsertTransactionJSON( +func (s *relayQueueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) @@ -87,7 +87,7 @@ func (s *transactionJSONStatements) InsertTransactionJSON( return lastid, nil } -func (s *transactionJSONStatements) DeleteTransactionJSON( +func (s *relayQueueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) @@ -95,7 +95,7 @@ func (s *transactionJSONStatements) DeleteTransactionJSON( return err } -func (s *transactionJSONStatements) SelectTransactionJSON( +func (s *relayQueueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { blobs := map[int64][]byte{} diff --git a/federationapi/storage/postgres/queue_transactions_table.go b/relayapi/storage/postgres/relay_queue_table.go similarity index 52% rename from federationapi/storage/postgres/queue_transactions_table.go rename to relayapi/storage/postgres/relay_queue_table.go index fe67ef9cd..c2ca78c6f 100644 --- a/federationapi/storage/postgres/queue_transactions_table.go +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -24,79 +24,79 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -const queueTransactionsSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( -- The transaction ID that was generated before persisting the event. transaction_id TEXT NOT NULL, -- The destination server that we will send the event to. server_name TEXT NOT NULL, - -- The JSON NID from the federationsender_transaction_json table. + -- The JSON NID from the relayapi_queue_json table. json_nid BIGINT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx - ON federationsender_queue_transactions (json_nid, server_name); -CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx - ON federationsender_queue_transactions (json_nid); -CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx - ON federationsender_queue_transactions (server_name); +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); ` -const insertQueueTransactionSQL = "" + - "INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + " VALUES ($1, $2, $3)" -const deleteQueueTransactionsSQL = "" + - "DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)" +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" -const selectQueueTransactionsSQL = "" + - "SELECT json_nid FROM federationsender_queue_transactions" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + " WHERE server_name = $1" + " LIMIT $2" -const selectQueueTransactionsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_transactions" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + " WHERE server_name = $1" -type queueTransactionsStatements struct { - db *sql.DB - insertQueueTransactionStmt *sql.Stmt - deleteQueueTransactionsStmt *sql.Stmt - selectQueueTransactionsStmt *sql.Stmt - selectQueueTransactionsCountStmt *sql.Stmt +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt } -func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { - s = &queueTransactionsStatements{ +func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ db: db, } - _, err = s.db.Exec(queueTransactionsSchema) + _, err = s.db.Exec(relayQueueSchema) if err != nil { return } - if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil { + if s.insertQueueEntryStmt, err = s.db.Prepare(insertQueueEntrySQL); err != nil { return } - if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil { + if s.deleteQueueEntriesStmt, err = s.db.Prepare(deleteQueueEntriesSQL); err != nil { return } - if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil { + if s.selectQueueEntriesStmt, err = s.db.Prepare(selectQueueEntriesSQL); err != nil { return } - if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil { + if s.selectQueueEntryCountStmt, err = s.db.Prepare(selectQueueEntryCountSQL); err != nil { return } return } -func (s *queueTransactionsStatements) InsertQueueTransaction( +func (s *relayQueueStatements) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) _, err := stmt.ExecContext( ctx, transactionID, // the transaction ID that we initially attempted @@ -106,22 +106,22 @@ func (s *queueTransactionsStatements) InsertQueueTransaction( return err } -func (s *queueTransactionsStatements) DeleteQueueTransactions( +func (s *relayQueueStatements) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt) + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) return err } -func (s *queueTransactionsStatements) SelectQueueTransactions( +func (s *relayQueueStatements) SelectQueueEntries( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) rows, err := stmt.QueryContext(ctx, serverName, limit) if err != nil { return nil, err @@ -139,11 +139,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions( return result, rows.Err() } -func (s *queueTransactionsStatements) SelectQueueTransactionCount( +func (s *relayQueueStatements) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&count) if err == sql.ErrNoRows { // It's acceptable for there to be no rows referencing a given diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 000000000..3902cc8ab --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,59 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 000000000..04045ae92 --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,142 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreAsyncTransaction( + ctx context.Context, txn gomatrixserverlib.Transaction, +) (*shared.Receipt, error) { + var err error + json, err := json.Marshal(txn) + if err != nil { + return nil, fmt.Errorf("d.JSONUnmarshall: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(json)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + receipt := shared.NewReceipt(nid) + return &receipt, nil +} + +func (d *Database) AssociateAsyncTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + receipt *shared.Receipt, +) error { + for destination := range destinations { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err := d.RelayQueue.InsertQueueEntry( + ctx, txn, transactionID, destination.Domain(), receipt.GetNID()) + return err + }) + if err != nil { + return fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + + return nil +} + +func (d *Database) CleanAsyncTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*shared.Receipt, +) error { + println(len(receipts)) + nids := make([]int64, len(receipts)) + for i, receipt := range receipts { + nids[i] = receipt.GetNID() + } + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + return err + }) + if err != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", err) + } + + return nil +} + +func (d *Database) GetAsyncTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *shared.Receipt, error) { + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + err = json.Unmarshal(txns[nids[0]], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err) + } + + receipt := shared.NewReceipt(nids[0]) + return transaction, &receipt, nil +} + +func (d *Database) GetAsyncTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/federationapi/storage/sqlite3/transaction_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go similarity index 60% rename from federationapi/storage/sqlite3/transaction_json_table.go rename to relayapi/storage/sqlite3/relay_queue_json_table.go index 30ad297ac..1635847a7 100644 --- a/federationapi/storage/sqlite3/transaction_json_table.go +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -24,53 +24,53 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" ) -const transactionJSONSchema = ` --- The federationsender_transaction_json table contains event contents that +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that -- we are storing for future forwarding. -CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( -- The JSON NID. This allows cross-referencing to find the JSON blob. json_nid INTEGER PRIMARY KEY AUTOINCREMENT, -- The JSON body. Text so that we preserve UTF-8. json_body TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx - ON federationsender_transaction_json (json_nid); +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); ` -const insertTransactionJSONSQL = "" + - "INSERT INTO federationsender_transaction_json (json_body)" + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + " VALUES ($1)" -const deleteTransactionJSONSQL = "" + - "DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)" +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" -const selectTransactionJSONSQL = "" + - "SELECT json_nid, json_body FROM federationsender_transaction_json" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + " WHERE json_nid IN ($1)" -type transactionJSONStatements struct { +type relayQueueJSONStatements struct { db *sql.DB insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic } -func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { - s = &transactionJSONStatements{ +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ db: db, } - _, err = db.Exec(transactionJSONSchema) + _, err = db.Exec(relayQueueJSONSchema) if err != nil { return } - if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil { + if s.insertJSONStmt, err = db.Prepare(insertQueueJSONSQL); err != nil { return } return } -func (s *transactionJSONStatements) InsertTransactionJSON( +func (s *relayQueueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (lastid int64, err error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) @@ -85,13 +85,13 @@ func (s *transactionJSONStatements) InsertTransactionJSON( return } -func (s *transactionJSONStatements) DeleteTransactionJSON( +func (s *relayQueueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { - deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) deleteStmt, err := txn.Prepare(deleteSQL) if err != nil { - return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err) + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) } iNIDs := make([]interface{}, len(nids)) @@ -104,13 +104,13 @@ func (s *transactionJSONStatements) DeleteTransactionJSON( return err } -func (s *transactionJSONStatements) SelectTransactionJSON( +func (s *relayQueueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { - selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) selectStmt, err := txn.Prepare(selectSQL) if err != nil { - return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err) + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) } iNIDs := make([]interface{}, len(jsonNIDs)) @@ -122,14 +122,14 @@ func (s *transactionJSONStatements) SelectTransactionJSON( stmt := sqlutil.TxStmt(txn, selectStmt) rows, err := stmt.QueryContext(ctx, iNIDs...) if err != nil { - return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err) + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) } - defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") for rows.Next() { var nid int64 var blob []byte if err = rows.Scan(&nid, &blob); err != nil { - return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err) + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) } blobs[nid] = blob } diff --git a/federationapi/storage/sqlite3/queue_transactions_table.go b/relayapi/storage/sqlite3/relay_queue_table.go similarity index 52% rename from federationapi/storage/sqlite3/queue_transactions_table.go rename to relayapi/storage/sqlite3/relay_queue_table.go index e616abe78..fdaa57f00 100644 --- a/federationapi/storage/sqlite3/queue_transactions_table.go +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -25,79 +25,79 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -const queueTransactionsSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( -- The transaction ID that was generated before persisting the event. transaction_id TEXT NOT NULL, -- The domain part of the user ID the m.room.member event is for. server_name TEXT NOT NULL, - -- The JSON NID from the federationsender_queue_transactions_json table. + -- The JSON NID from the relayapi_queue_json table. json_nid BIGINT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx - ON federationsender_queue_transactions (json_nid, server_name); -CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx - ON federationsender_queue_transactions (json_nid); -CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx - ON federationsender_queue_transactions (server_name); +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); ` -const insertQueueTransactionSQL = "" + - "INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + " VALUES ($1, $2, $3)" -const deleteQueueTransactionsSQL = "" + - "DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)" +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" -const selectQueueTransactionsSQL = "" + - "SELECT json_nid FROM federationsender_queue_transactions" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + " WHERE server_name = $1" + " LIMIT $2" -const selectQueueTransactionsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_transactions" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + " WHERE server_name = $1" -type queueTransactionsStatements struct { - db *sql.DB - insertQueueTransactionStmt *sql.Stmt - selectQueueTransactionsStmt *sql.Stmt - selectQueueTransactionsCountStmt *sql.Stmt - // deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic } -func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { - s = &queueTransactionsStatements{ +func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ db: db, } - _, err = db.Exec(queueTransactionsSchema) + _, err = db.Exec(relayQueueSchema) if err != nil { return } - if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil { + if s.insertQueueEntryStmt, err = db.Prepare(insertQueueEntrySQL); err != nil { return } - //if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil { + //if s.deleteQueueEntriesStmt, err = db.Prepare(deleteQueueEntriesSQL); err != nil { // return //} - if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil { + if s.selectQueueEntriesStmt, err = db.Prepare(selectQueueEntriesSQL); err != nil { return } - if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil { + if s.selectQueueEntryCountStmt, err = db.Prepare(selectQueueEntryCountSQL); err != nil { return } return } -func (s *queueTransactionsStatements) InsertQueueTransaction( +func (s *relayQueueStatements) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) _, err := stmt.ExecContext( ctx, transactionID, // the transaction ID that we initially attempted @@ -107,15 +107,15 @@ func (s *queueTransactionsStatements) InsertQueueTransaction( return err } -func (s *queueTransactionsStatements) DeleteQueueTransactions( +func (s *relayQueueStatements) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { - deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) deleteStmt, err := txn.Prepare(deleteSQL) if err != nil { - return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err) + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) } params := make([]interface{}, len(jsonNIDs)+1) @@ -129,12 +129,12 @@ func (s *queueTransactionsStatements) DeleteQueueTransactions( return err } -func (s *queueTransactionsStatements) SelectQueueTransactions( +func (s *relayQueueStatements) SelectQueueEntries( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) rows, err := stmt.QueryContext(ctx, serverName, limit) if err != nil { return nil, err @@ -152,11 +152,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions( return result, rows.Err() } -func (s *queueTransactionsStatements) SelectQueueTransactionCount( +func (s *relayQueueStatements) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&count) if err == sql.ErrNoRows { // It's acceptable for there to be no rows referencing a given diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..558f5120e --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go new file mode 100644 index 000000000..e4cefc1fd --- /dev/null +++ b/relayapi/storage/storage.go @@ -0,0 +1,41 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 000000000..a615c7848 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,35 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayQueue interface { + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type RelayQueueJSON interface { + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/federationapi/storage/tables/transaction_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go similarity index 77% rename from federationapi/storage/tables/transaction_json_table_test.go rename to relayapi/storage/tables/relay_queue_json_table_test.go index 9569b0f0c..64ced1985 100644 --- a/federationapi/storage/tables/transaction_json_table_test.go +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -8,10 +8,10 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/federationapi/storage/postgres" - "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" - "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" @@ -35,31 +35,31 @@ func mustCreateTransaction() gomatrixserverlib.Transaction { return txn } -type TransactionJSONDatabase struct { +type RelayQueueJSONDatabase struct { DB *sql.DB Writer sqlutil.Writer - Table tables.FederationTransactionJSON + Table tables.RelayQueueJSON } -func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) { +func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sqlutil.NewExclusiveWriter()) assert.NoError(t, err) - var tab tables.FederationTransactionJSON + var tab tables.RelayQueueJSON switch dbType { case test.DBTypePostgres: - tab, err = postgres.NewPostgresTransactionJSONTable(db) + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) assert.NoError(t, err) case test.DBTypeSQLite: - tab, err = sqlite3.NewSQLiteTransactionJSONTable(db) + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) assert.NoError(t, err) } assert.NoError(t, err) - database = TransactionJSONDatabase{ + database = RelayQueueJSONDatabase{ DB: db, Writer: sqlutil.NewDummyWriter(), Table: tab, @@ -70,7 +70,7 @@ func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database func TestShoudInsertTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateTransactionJSONTable(t, dbType) + db, close := mustCreateQueueJSONTable(t, dbType) defer close() transaction := mustCreateTransaction() @@ -79,7 +79,7 @@ func TestShoudInsertTransaction(t *testing.T) { t.Fatalf("Invalid transaction: %s", err.Error()) } - _, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } @@ -89,7 +89,7 @@ func TestShoudInsertTransaction(t *testing.T) { func TestShouldRetrieveInsertedTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateTransactionJSONTable(t, dbType) + db, close := mustCreateQueueJSONTable(t, dbType) defer close() transaction := mustCreateTransaction() @@ -98,14 +98,14 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) { t.Fatalf("Invalid transaction: %s", err.Error()) } - nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } var storedJSON map[int64][]byte _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { - storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) return err }) if err != nil { @@ -124,7 +124,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) { func TestShouldDeleteTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateTransactionJSONTable(t, dbType) + db, close := mustCreateQueueJSONTable(t, dbType) defer close() transaction := mustCreateTransaction() @@ -133,14 +133,14 @@ func TestShouldDeleteTransaction(t *testing.T) { t.Fatalf("Invalid transaction: %s", err.Error()) } - nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } storedJSON := map[int64][]byte{} _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { - err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid}) + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) return err }) if err != nil { @@ -149,7 +149,7 @@ func TestShouldDeleteTransaction(t *testing.T) { storedJSON = map[int64][]byte{} _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { - storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) return err }) if err != nil { diff --git a/federationapi/storage/tables/queue_transactions_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go similarity index 68% rename from federationapi/storage/tables/queue_transactions_table_test.go rename to relayapi/storage/tables/relay_queue_table_test.go index 9266f6c95..26d764583 100644 --- a/federationapi/storage/tables/queue_transactions_table_test.go +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -7,41 +7,41 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/federationapi/storage/postgres" - "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" - "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) -type QueueTransactionsDatabase struct { +type RelayQueueDatabase struct { DB *sql.DB Writer sqlutil.Writer - Table tables.FederationQueueTransactions + Table tables.RelayQueue } -func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) { +func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sqlutil.NewExclusiveWriter()) assert.NoError(t, err) - var tab tables.FederationQueueTransactions + var tab tables.RelayQueue switch dbType { case test.DBTypePostgres: - tab, err = postgres.NewPostgresQueueTransactionsTable(db) + tab, err = postgres.NewPostgresRelayQueueTable(db) assert.NoError(t, err) case test.DBTypeSQLite: - tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db) + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) assert.NoError(t, err) } assert.NoError(t, err) - database = QueueTransactionsDatabase{ + database = RelayQueueDatabase{ DB: db, Writer: sqlutil.NewDummyWriter(), Table: tab, @@ -52,13 +52,13 @@ func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (databas func TestShoudInsertQueueTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateQueueTransactionsTable(t, dbType) + db, close := mustCreateQueueTable(t, dbType) defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) serverName := gomatrixserverlib.ServerName("domain") nid := int64(1) - err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } @@ -68,19 +68,19 @@ func TestShoudInsertQueueTransaction(t *testing.T) { func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateQueueTransactionsTable(t, dbType) + db, close := mustCreateQueueTable(t, dbType) defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) serverName := gomatrixserverlib.ServerName("domain") nid := int64(1) - err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10) + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) if err != nil { t.Fatalf("Failed retrieving transaction: %s", err.Error()) } @@ -93,27 +93,27 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { func TestShouldDeleteQueueTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateQueueTransactionsTable(t, dbType) + db, close := mustCreateQueueTable(t, dbType) defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) serverName := gomatrixserverlib.ServerName("domain") nid := int64(1) - err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { - err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) return err }) if err != nil { t.Fatalf("Failed deleting transaction: %s", err.Error()) } - count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } @@ -124,7 +124,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) { func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateQueueTransactionsTable(t, dbType) + db, close := mustCreateQueueTable(t, dbType) defer close() transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) @@ -135,34 +135,34 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { nid2 := int64(2) transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) - err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) if err != nil { t.Fatalf("Failed inserting transaction: %s", err.Error()) } _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { - err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) return err }) if err != nil { t.Fatalf("Failed deleting transaction: %s", err.Error()) } - count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } assert.Equal(t, int64(1), count) - count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2) + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) if err != nil { t.Fatalf("Failed retrieving transaction count: %s", err.Error()) } diff --git a/setup/config/config.go b/setup/config/config.go index 7e7ed1aa1..a2d1d3def 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -62,6 +62,7 @@ type Dendrite struct { RoomServer RoomServer `yaml:"room_server"` SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + RelayAPI RelayAPI `yaml:"relay_api"` MSCs MSCs `yaml:"mscs"` diff --git a/setup/config/config_relayapi.go b/setup/config/config_relayapi.go new file mode 100644 index 000000000..d2aac8cb2 --- /dev/null +++ b/setup/config/config_relayapi.go @@ -0,0 +1,38 @@ +package config + +type RelayAPI struct { + Matrix *Global `yaml:"-"` + + InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` + ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` + + // The database stores information used by the relay queue to + // forward transactions to remote servers. + Database DatabaseOptions `yaml:"database,omitempty"` +} + +func (c *RelayAPI) Defaults(opts DefaultOpts) { + if !opts.Monolithic { + c.InternalAPI.Listen = "http://localhost:7775" + c.InternalAPI.Connect = "http://localhost:7775" + c.ExternalAPI.Listen = "http://[::]:8075" + c.Database.Defaults(10) + } + if opts.Generate { + if !opts.Monolithic { + c.Database.ConnectionString = "file:relayapi.db" + } + } +} + +func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { + if isMonolith { // polylith required configs below + return + } + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString)) + } + checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect)) +} diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..8b88b6555 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -23,6 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -44,6 +46,7 @@ type Monolith struct { RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI KeyAPI keyAPI.KeyInternalAPI + RelayAPI relayAPI.RelayInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -71,4 +74,9 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { syncapi.AddPublicRoutes( base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, ) + + // TODO : relayapi.AddPublicRoutes + if m.RelayAPI != nil { + relayapi.AddPublicRoutes(base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.RelayAPI, m.FederationAPI, m.KeyAPI) + } }