Refactor all relay specific stuff into it's own component

This commit is contained in:
Devon Hudson 2022-12-14 18:41:27 -07:00
parent f300a4d0e9
commit ad53326ce8
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
39 changed files with 1433 additions and 850 deletions

View file

@ -458,7 +458,7 @@ func (m *DendriteMonolith) Start() {
switch e := event.(type) { switch e := event.(type) {
case pineconeEvents.PeerAdded: case pineconeEvents.PeerAdded:
if !relayServerSyncRunning.Load() { if !relayServerSyncRunning.Load() {
go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) // go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
} }
case pineconeEvents.PeerRemoved: case pineconeEvents.PeerRemoved:
if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 { if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 {
@ -486,46 +486,6 @@ func (m *DendriteMonolith) Start() {
}(pineconeEventChannel) }(pineconeEventChannel)
} }
func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false)
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.relayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
}
m.queryRelayServers(relayServersToQuery)
t.Reset(relayServerRetryInterval)
select {
case <-stop:
if !t.Stop() {
<-t.C
}
return
case <-t.C:
}
}
}
func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
for _, server := range relayServers {
request := api.PerformRelayServerSyncRequest{RelayServer: server}
response := api.PerformRelayServerSyncResponse{}
err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response)
if err == nil {
m.relayServersQueried[server] = true
}
}
}
func (m *DendriteMonolith) Stop() { func (m *DendriteMonolith) Stop() {
m.processContext.ShutdownDendrite() m.processContext.ShutdownDendrite()
_ = m.listener.Close() _ = m.listener.Close()

View file

@ -43,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi" "github.com/matrix-org/dendrite/relayapi"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
@ -145,6 +146,7 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.RegistrationDisabled = false
@ -235,6 +237,20 @@ func main() {
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,
Client: conn.CreateClient(base, pQUIC), Client: conn.CreateClient(base, pQUIC),
@ -246,6 +262,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyAPI, KeyAPI: keyAPI,
RelayAPI: &relayAPI,
ExtPublicRoomsProvider: roomProvider, ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider, ExtUserDirectoryProvider: userProvider,
} }
@ -319,32 +336,12 @@ func main() {
relayServerSyncRunning := atomic.NewBool(false) relayServerSyncRunning := atomic.NewBool(false)
stopRelayServerSync := make(chan bool) stopRelayServerSync := make(chan bool)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.FederationAPI.Matrix.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
m := RelayServerRetriever{ m := RelayServerRetriever{
Context: context.Background(), Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
FederationAPI: fsAPI, FederationAPI: fsAPI,
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: relayapi.NewRelayAPI( RelayAPI: monolith.RelayAPI,
federation,
rsAPI,
keyRing,
producer,
cfg.Global.Presence.EnableInbound,
cfg.Global.ServerName,
),
} }
m.InitializeRelayServers(eLog) m.InitializeRelayServers(eLog)
@ -387,7 +384,7 @@ type RelayServerRetriever struct {
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI FederationAPI api.FederationInternalAPI
RelayServersQueried map[gomatrixserverlib.ServerName]bool RelayServersQueried map[gomatrixserverlib.ServerName]bool
RelayAPI relayapi.RelayAPI RelayAPI relayServerAPI.RelayInternalAPI
} }
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
@ -440,7 +437,12 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli
if err != nil { if err != nil {
return return
} }
err = m.RelayAPI.PerformRelayServerSync(*userID, server) request := relayServerAPI.PerformRelayServerSyncRequest{
UserID: *userID,
RelayServer: server,
}
response := relayServerAPI.PerformRelayServerSyncResponse{}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response)
if err == nil { if err == nil {
m.RelayServersQueried[server] = true m.RelayServersQueried[server] = true
// TODO : What happens if your relay receives new messages after this point? // TODO : What happens if your relay receives new messages after this point?

View file

@ -18,7 +18,6 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
ClientFederationAPI ClientFederationAPI
RoomserverFederationAPI RoomserverFederationAPI
RelayServerAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -45,22 +44,6 @@ type FederationInternalAPI interface {
) error ) error
} }
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type ClientFederationAPI interface { type ClientFederationAPI interface {
// Query the server names of the joined hosts in a room. // Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice

View file

@ -14,7 +14,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers" "github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
) )
@ -847,64 +846,3 @@ func (r *FederationInternalAPI) QueryRelayServers(
response.RelayServers = relayServers response.RelayServers = relayServers
return nil return nil
} }
// PerformStoreAsync implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.FederationInternalAPI
func (r *FederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}

View file

@ -26,9 +26,6 @@ const (
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
@ -525,25 +522,3 @@ func (h *httpFederationInternalAPI) QueryRelayServers(
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }
func (h *httpFederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.federationAPIURL+FederationAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -43,16 +43,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
) )
internalAPIMux.Handle(
FederationAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
FederationAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
internalAPIMux.Handle( internalAPIMux.Handle(
FederationAPIPerformWakeupServers, FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),

View file

@ -1,199 +0,0 @@
package routing_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func createDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

View file

@ -133,37 +133,6 @@ func Setup(
}, },
)).Methods(http.MethodPut, http.MethodOptions) )).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI(
"federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {

View file

@ -51,12 +51,6 @@ type Database interface {
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
// these don't have contexts passed in as we want things to happen regardless of the request context // these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error

View file

@ -58,18 +58,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueTransactions, err := NewPostgresQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresQueueJSONTable(d.db) queueJSON, err := NewPostgresQueueJSONTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
transactionJSON, err := NewPostgresTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,24 +103,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
IsLocalServerName: isLocalServerName, IsLocalServerName: isLocalServerName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs, FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs, FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON, FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions, FederationBlacklist: blacklist,
FederationTransactionJSON: transactionJSON, FederationAssumedOffline: assumedOffline,
FederationBlacklist: blacklist, FederationRelayServers: relayServers,
FederationAssumedOffline: assumedOffline, FederationInboundPeeks: inboundPeeks,
FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryJSON,
FederationOutboundPeeks: outboundPeeks, NotaryServerKeysMetadata: notaryMetadata,
NotaryServerKeysJSON: notaryJSON, ServerSigningKeys: serverSigningKeys,
NotaryServerKeysMetadata: notaryMetadata,
ServerSigningKeys: serverSigningKeys,
} }
return &d, nil return &d, nil
} }

View file

@ -17,7 +17,6 @@ package shared
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"time" "time"
@ -29,24 +28,22 @@ import (
) )
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache Cache caching.FederationCache
Writer sqlutil.Writer Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs FederationQueuePDUs tables.FederationQueuePDUs
FederationQueueEDUs tables.FederationQueueEDUs FederationQueueEDUs tables.FederationQueueEDUs
FederationQueueJSON tables.FederationQueueJSON FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys ServerSigningKeys tables.FederationServerSigningKeys
FederationQueueTransactions tables.FederationQueueTransactions
FederationTransactionJSON tables.FederationTransactionJSON
} }
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
@ -61,6 +58,10 @@ func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid} return Receipt{nid: nid}
} }
func (r *Receipt) GetNID() int64 {
return r.nid
}
func (r *Receipt) String() string { func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid) return fmt.Sprintf("%d", r.nid)
} }
@ -308,108 +309,3 @@ func (d *Database) GetNotaryKeys(
}) })
return sks, err return sks, err
} }
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertTransactionJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.InsertQueueTransaction(
ctx, txn, transactionID, destination.Domain(), receipt.nid)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.nid
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *Receipt, error) {
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.FederationTransactionJSON.SelectTransactionJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err)
}
return count, nil
}

View file

@ -55,14 +55,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
transactionJSON, err := NewSQLiteTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,24 +96,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
IsLocalServerName: isLocalServerName, IsLocalServerName: isLocalServerName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs, FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs, FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON, FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions, FederationBlacklist: blacklist,
FederationTransactionJSON: transactionJSON, FederationAssumedOffline: assumedOffline,
FederationBlacklist: blacklist, FederationRelayServers: relayServers,
FederationAssumedOffline: assumedOffline, FederationOutboundPeeks: outboundPeeks,
FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryKeys,
FederationInboundPeeks: inboundPeeks, NotaryServerKeysMetadata: notaryKeysMetadata,
NotaryServerKeysJSON: notaryKeys, ServerSigningKeys: serverSigningKeys,
NotaryServerKeysMetadata: notaryKeysMetadata,
ServerSigningKeys: serverSigningKeys,
} }
return &d, nil return &d, nil
} }

81
relayapi/api/api.go Normal file
View file

@ -0,0 +1,81 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
// RelayInternalAPI is used to query information from the relay server.
type RelayInternalAPI interface {
RelayServerAPI
PerformRelayServerSync(
ctx context.Context,
request *PerformRelayServerSyncRequest,
response *PerformRelayServerSyncResponse,
) error
}
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type PerformRelayServerSyncRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
RelayServer gomatrixserverlib.ServerName `json:"relay_name"`
}
type PerformRelayServerSyncResponse struct {
}
type QueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName
}
type QueryRelayServersResponse struct {
RelayServers []gomatrixserverlib.ServerName
}
type PerformStoreAsyncRequest struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type PerformStoreAsyncResponse struct {
}
type QueryAsyncTransactionsRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type QueryAsyncTransactionsResponse struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
RemainingCount uint32 `json:"remaining"`
}

52
relayapi/internal/api.go Normal file
View file

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayInternalAPI struct {
db storage.Database
fedClient *gomatrixserverlib.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
}
func NewRelayInternalAPI(
db storage.Database,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) RelayInternalAPI {
return RelayInternalAPI{
db: db,
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}

View file

@ -0,0 +1,131 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// PerformRelayServerSync implements api.FederationInternalAPI
func (r *RelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
}
return nil
}
// PerformStoreAsync implements api.RelayInternalAPI
func (r *RelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.RelayInternalAPI
func (r *RelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}
func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
}

View file

@ -0,0 +1,70 @@
package inthttp
import (
"context"
"errors"
"net/http"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// HTTP paths for the internal HTTP API
const (
RelayAPIPerformRelayServerSyncPath = "/relayapi/performRelayServerSync"
RelayAPIPerformStoreAsyncPath = "/relayapi/performStoreAsync"
RelayAPIQueryAsyncTransactionsPath = "/relayapi/queryAsyncTransactions"
)
// NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API.
// If httpClient is nil an error is returned
func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) {
if httpClient == nil {
return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is <nil>")
}
return &httpRelayInternalAPI{
relayAPIURL: relayapiURL,
httpClient: httpClient,
cache: cache,
}, nil
}
type httpRelayInternalAPI struct {
relayAPIURL string
httpClient *http.Client
cache caching.ServerKeyCache
}
func (h *httpRelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformRelayServerSync", h.relayAPIURL+RelayAPIPerformRelayServerSyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.relayAPIURL+RelayAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.relayAPIURL+RelayAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -0,0 +1,27 @@
package inthttp
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// AddRoutes adds the RelayInternalAPI handlers to the http.ServeMux.
// nolint:gocyclo
func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
RelayAPIPerformRelayServerSyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", intAPI.PerformRelayServerSync),
)
internalAPIMux.Handle(
RelayAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
RelayAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("RelayAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
}

View file

@ -15,79 +15,77 @@
package relayapi package relayapi
import ( import (
"context" "github.com/gorilla/mux"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/relayapi/api"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/inthttp"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type RelayAPI struct { // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
fedClient *gomatrixserverlib.FederationClient // on the given input API.
rsAPI rsAPI.RoomserverInternalAPI func AddInternalRoutes(router *mux.Router, intAPI api.RelayInternalAPI) {
keyRing *gomatrixserverlib.KeyRing inthttp.AddRoutes(intAPI, router)
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
} }
func NewRelayAPI( // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.UserInternalAPI,
fedClient *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI rsAPI.FederationRoomserverAPI,
relayAPI relayAPI.RelayInternalAPI,
fedAPI federationAPI.FederationInternalAPI,
keyAPI keyserverAPI.FederationKeyAPI,
) {
fedCfg := &base.Cfg.FederationAPI
relay, ok := relayAPI.(*internal.RelayInternalAPI)
if !ok {
panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
"RelayInternalAPI. This is a programming error.")
}
routing.Setup(
base.PublicFederationAPIMux,
fedCfg,
relay,
keyRing,
)
}
func NewRelayInternalAPI(
base *base.BaseDendrite,
fedClient *gomatrixserverlib.FederationClient, fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI, rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer, producer *producers.SyncAPIProducer,
presenceEnabledInbound bool, ) internal.RelayInternalAPI {
serverName gomatrixserverlib.ServerName, cfg := &base.Cfg.RelayAPI
) RelayAPI {
return RelayAPI{
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}
// PerformRelayServerSync implements api.FederationInternalAPI relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
func (r *RelayAPI) PerformRelayServerSync(userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
if err != nil { if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error()) logrus.WithError(err).Panic("failed to connect to relay db")
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
} }
return nil return internal.NewRelayInternalAPI(
} relayDB,
fedClient,
func (r *RelayAPI) processTransaction(txn *gomatrixserverlib.Transaction) { rsAPI,
logrus.Warn("Processing transaction from relay server") keyRing,
mu := internal.NewMutexByRoom() producer,
t := internal.NewTxnReq( base.Cfg.Global.Presence.EnableInbound,
r.rsAPI, base.Cfg.Global.ServerName,
nil, )
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
} }

View file

@ -3,7 +3,7 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,12 +19,12 @@ type AsyncEventsResponse struct {
func GetAsyncEvents( func GetAsyncEvents(
httpReq *http.Request, httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest, fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI, relayAPI api.RelayInternalAPI,
userID gomatrixserverlib.UserID, userID gomatrixserverlib.UserID,
) util.JSONResponse { ) util.JSONResponse {
logrus.Infof("Handling async_events for %v", userID) logrus.Infof("Handling async_events for %v", userID)
var response api.QueryAsyncTransactionsResponse var response api.QueryAsyncTransactionsResponse
err := fedAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response) err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,

View file

@ -5,21 +5,21 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) { func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
testDB := createDatabase() testDB := storage.NewFakeRelayDatabase()
db := shared.Database{ db := shared.Database{
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB, RelayQueue: testDB,
FederationTransactionJSON: testDB, RelayQueueJSON: testDB,
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -33,11 +33,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
t.Fatalf("Failed to store transaction: %s", err.Error()) t.Fatalf("Failed to store transaction: %s", err.Error())
} }
fedAPI := internal.NewFederationInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, &db, nil, nil, nil, nil, false, "",
) )
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse) jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -46,11 +46,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
} }
func TestGetAsyncReturnsSavedTransaction(t *testing.T) { func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
testDB := createDatabase() testDB := storage.NewFakeRelayDatabase()
db := shared.Database{ db := shared.Database{
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB, RelayQueue: testDB,
FederationTransactionJSON: testDB, RelayQueueJSON: testDB,
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -74,11 +74,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error()) t.Fatalf("Failed to associate transaction with user: %s", err.Error())
} }
fedAPI := internal.NewFederationInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, &db, nil, nil, nil, nil, false, "",
) )
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse) jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -87,11 +87,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
} }
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) { func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
testDB := createDatabase() testDB := storage.NewFakeRelayDatabase()
db := shared.Database{ db := shared.Database{
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB, RelayQueue: testDB,
FederationTransactionJSON: testDB, RelayQueueJSON: testDB,
} }
httpReq := &http.Request{} httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false) userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -131,18 +131,18 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error()) t.Fatalf("Failed to associate transaction with user: %s", err.Error())
} }
fedAPI := internal.NewFederationInternalAPI( relayAPI := internal.NewRelayInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil, &db, nil, nil, nil, nil, false, "",
) )
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse) jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(1), jsonResponse.Remaining) assert.Equal(t, uint32(1), jsonResponse.Remaining)
assert.Equal(t, transaction, jsonResponse.Transaction) assert.Equal(t, transaction, jsonResponse.Transaction)
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID) response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code) assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse) jsonResponse = response.JSON.(routing.AsyncEventsResponse)

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -15,7 +15,7 @@ import (
func ForwardAsync( func ForwardAsync(
httpReq *http.Request, httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest, fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI, relayAPI api.RelayInternalAPI,
txnID gomatrixserverlib.TransactionID, txnID gomatrixserverlib.TransactionID,
userID gomatrixserverlib.UserID, userID gomatrixserverlib.UserID,
) util.JSONResponse { ) util.JSONResponse {
@ -54,7 +54,7 @@ func ForwardAsync(
UserID: userID, UserID: userID,
} }
res := api.PerformStoreAsyncResponse{} res := api.PerformStoreAsyncResponse{}
err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res) err := relayAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,

View file

@ -0,0 +1,111 @@
package routing_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(
httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

123
relayapi/routing/routing.go Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"net/http"
"time"
"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
//
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
func Setup(
fedMux *mux.Router,
cfg *config.FederationAPI,
relayAPI *relayInternal.RelayInternalAPI,
keys gomatrixserverlib.JSONVerifier,
) {
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeRelayAPI(
"relay_forward_async", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeRelayAPI(
"relay_async_events", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, relayAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
}
// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
func MakeRelayAPI(
metricsName string, serverName gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
}
jsonRes := f(req, fedReq, vars)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return httputil.MakeExternalAPI(metricsName, h)
}

View file

@ -0,0 +1,95 @@
package storage
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func NewFakeRelayDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}

View file

@ -0,0 +1,30 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
}

View file

@ -23,60 +23,60 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
) )
const transactionJSONSchema = ` const relayQueueJSONSchema = `
-- The federationsender_transaction_json table contains event contents that -- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding. -- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob. -- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid BIGSERIAL, json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON federationsender_transaction_json (json_nid); ON relayapi_queue_json (json_nid);
` `
const insertTransactionJSONSQL = "" + const insertQueueJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" + "INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)" + " VALUES ($1)" +
" RETURNING json_nid" " RETURNING json_nid"
const deleteTransactionJSONSQL = "" + const deleteQueueJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)" "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
const selectTransactionJSONSQL = "" + const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" + "SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid = ANY($1)" " WHERE json_nid = ANY($1)"
type transactionJSONStatements struct { type relayQueueJSONStatements struct {
db *sql.DB db *sql.DB
insertJSONStmt *sql.Stmt insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt selectJSONStmt *sql.Stmt
} }
func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &transactionJSONStatements{ s = &relayQueueJSONStatements{
db: db, db: db,
} }
_, err = s.db.Exec(transactionJSONSchema) _, err = s.db.Exec(relayQueueJSONSchema)
if err != nil { if err != nil {
return return
} }
if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil { if s.insertJSONStmt, err = s.db.Prepare(insertQueueJSONSQL); err != nil {
return return
} }
if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil { if s.deleteJSONStmt, err = s.db.Prepare(deleteQueueJSONSQL); err != nil {
return return
} }
if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil { if s.selectJSONStmt, err = s.db.Prepare(selectQueueJSONSQL); err != nil {
return return
} }
return return
} }
func (s *transactionJSONStatements) InsertTransactionJSON( func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string, ctx context.Context, txn *sql.Tx, json string,
) (int64, error) { ) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -87,7 +87,7 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return lastid, nil return lastid, nil
} }
func (s *transactionJSONStatements) DeleteTransactionJSON( func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64, ctx context.Context, txn *sql.Tx, nids []int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
@ -95,7 +95,7 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err return err
} }
func (s *transactionJSONStatements) SelectTransactionJSON( func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) { ) (map[int64][]byte, error) {
blobs := map[int64][]byte{} blobs := map[int64][]byte{}

View file

@ -24,79 +24,79 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const queueTransactionsSchema = ` const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event. -- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL, transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to. -- The destination server that we will send the event to.
server_name TEXT NOT NULL, server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_transaction_json table. -- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL json_nid BIGINT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name); ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON federationsender_queue_transactions (json_nid); ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON federationsender_queue_transactions (server_name); ON relayapi_queue (server_name);
` `
const insertQueueTransactionSQL = "" + const insertQueueEntrySQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)" " VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" + const deleteQueueEntriesSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)" "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueTransactionsSQL = "" + const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" + "SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" + " WHERE server_name = $1" +
" LIMIT $2" " LIMIT $2"
const selectQueueTransactionsCountSQL = "" + const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" + "SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1" " WHERE server_name = $1"
type queueTransactionsStatements struct { type relayQueueStatements struct {
db *sql.DB db *sql.DB
insertQueueTransactionStmt *sql.Stmt insertQueueEntryStmt *sql.Stmt
deleteQueueTransactionsStmt *sql.Stmt deleteQueueEntriesStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt selectQueueEntriesStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt selectQueueEntryCountStmt *sql.Stmt
} }
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &queueTransactionsStatements{ s = &relayQueueStatements{
db: db, db: db,
} }
_, err = s.db.Exec(queueTransactionsSchema) _, err = s.db.Exec(relayQueueSchema)
if err != nil { if err != nil {
return return
} }
if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil { if s.insertQueueEntryStmt, err = s.db.Prepare(insertQueueEntrySQL); err != nil {
return return
} }
if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil { if s.deleteQueueEntriesStmt, err = s.db.Prepare(deleteQueueEntriesSQL); err != nil {
return return
} }
if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil { if s.selectQueueEntriesStmt, err = s.db.Prepare(selectQueueEntriesSQL); err != nil {
return return
} }
if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil { if s.selectQueueEntryCountStmt, err = s.db.Prepare(selectQueueEntryCountSQL); err != nil {
return return
} }
return return
} }
func (s *queueTransactionsStatements) InsertQueueTransaction( func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
transactionID, // the transaction ID that we initially attempted transactionID, // the transaction ID that we initially attempted
@ -106,22 +106,22 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err return err
} }
func (s *queueTransactionsStatements) DeleteQueueTransactions( func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
jsonNIDs []int64, jsonNIDs []int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt) stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err return err
} }
func (s *queueTransactionsStatements) SelectQueueTransactions( func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@ -139,11 +139,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err() return result, rows.Err()
} }
func (s *queueTransactionsStatements) SelectQueueTransactionCount( func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count) err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given // It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,59 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the relayapi
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
queue, err := NewPostgresRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
RelayQueue tables.RelayQueue
RelayQueueJSON tables.RelayQueueJSON
}
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*shared.Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
receipt := shared.NewReceipt(nid)
return &receipt, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *shared.Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.InsertQueueEntry(
ctx, txn, transactionID, destination.Domain(), receipt.GetNID())
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueEntry: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*shared.Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.GetNID()
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.deleteQueueEntries: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *shared.Receipt, error) {
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := shared.NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
}
return count, nil
}

View file

@ -24,53 +24,53 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
) )
const transactionJSONSchema = ` const relayQueueJSONSchema = `
-- The federationsender_transaction_json table contains event contents that -- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding. -- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob. -- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT, json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON federationsender_transaction_json (json_nid); ON relayapi_queue_json (json_nid);
` `
const insertTransactionJSONSQL = "" + const insertQueueJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" + "INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)" " VALUES ($1)"
const deleteTransactionJSONSQL = "" + const deleteQueueJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)" "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
const selectTransactionJSONSQL = "" + const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" + "SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid IN ($1)" " WHERE json_nid IN ($1)"
type transactionJSONStatements struct { type relayQueueJSONStatements struct {
db *sql.DB db *sql.DB
insertJSONStmt *sql.Stmt insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &transactionJSONStatements{ s = &relayQueueJSONStatements{
db: db, db: db,
} }
_, err = db.Exec(transactionJSONSchema) _, err = db.Exec(relayQueueJSONSchema)
if err != nil { if err != nil {
return return
} }
if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil { if s.insertJSONStmt, err = db.Prepare(insertQueueJSONSQL); err != nil {
return return
} }
return return
} }
func (s *transactionJSONStatements) InsertTransactionJSON( func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string, ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) { ) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -85,13 +85,13 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return return
} }
func (s *transactionJSONStatements) DeleteTransactionJSON( func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64, ctx context.Context, txn *sql.Tx, nids []int64,
) error { ) error {
deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL) deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil { if err != nil {
return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err) return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
} }
iNIDs := make([]interface{}, len(nids)) iNIDs := make([]interface{}, len(nids))
@ -104,13 +104,13 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err return err
} }
func (s *transactionJSONStatements) SelectTransactionJSON( func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) { ) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL) selectStmt, err := txn.Prepare(selectSQL)
if err != nil { if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err) return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
} }
iNIDs := make([]interface{}, len(jsonNIDs)) iNIDs := make([]interface{}, len(jsonNIDs))
@ -122,14 +122,14 @@ func (s *transactionJSONStatements) SelectTransactionJSON(
stmt := sqlutil.TxStmt(txn, selectStmt) stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...) rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err) return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
for rows.Next() { for rows.Next() {
var nid int64 var nid int64
var blob []byte var blob []byte
if err = rows.Scan(&nid, &blob); err != nil { if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err) return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
} }
blobs[nid] = blob blobs[nid] = blob
} }

View file

@ -25,79 +25,79 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const queueTransactionsSchema = ` const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event. -- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL, transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for. -- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL, server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_transactions_json table. -- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL json_nid BIGINT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name); ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON federationsender_queue_transactions (json_nid); ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON federationsender_queue_transactions (server_name); ON relayapi_queue (server_name);
` `
const insertQueueTransactionSQL = "" + const insertQueueEntrySQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)" " VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" + const deleteQueueEntriesSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)" "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueTransactionsSQL = "" + const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" + "SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" + " WHERE server_name = $1" +
" LIMIT $2" " LIMIT $2"
const selectQueueTransactionsCountSQL = "" + const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" + "SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1" " WHERE server_name = $1"
type queueTransactionsStatements struct { type relayQueueStatements struct {
db *sql.DB db *sql.DB
insertQueueTransactionStmt *sql.Stmt insertQueueEntryStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt selectQueueEntriesStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt selectQueueEntryCountStmt *sql.Stmt
// deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &queueTransactionsStatements{ s = &relayQueueStatements{
db: db, db: db,
} }
_, err = db.Exec(queueTransactionsSchema) _, err = db.Exec(relayQueueSchema)
if err != nil { if err != nil {
return return
} }
if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil { if s.insertQueueEntryStmt, err = db.Prepare(insertQueueEntrySQL); err != nil {
return return
} }
//if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil { //if s.deleteQueueEntriesStmt, err = db.Prepare(deleteQueueEntriesSQL); err != nil {
// return // return
//} //}
if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil { if s.selectQueueEntriesStmt, err = db.Prepare(selectQueueEntriesSQL); err != nil {
return return
} }
if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil { if s.selectQueueEntryCountStmt, err = db.Prepare(selectQueueEntryCountSQL); err != nil {
return return
} }
return return
} }
func (s *queueTransactionsStatements) InsertQueueTransaction( func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
transactionID, // the transaction ID that we initially attempted transactionID, // the transaction ID that we initially attempted
@ -107,15 +107,15 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err return err
} }
func (s *queueTransactionsStatements) DeleteQueueTransactions( func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
jsonNIDs []int64, jsonNIDs []int64,
) error { ) error {
deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL) deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil { if err != nil {
return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err) return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
} }
params := make([]interface{}, len(jsonNIDs)+1) params := make([]interface{}, len(jsonNIDs)+1)
@ -129,12 +129,12 @@ func (s *queueTransactionsStatements) DeleteQueueTransactions(
return err return err
} }
func (s *queueTransactionsStatements) SelectQueueTransactions( func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit) rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@ -152,11 +152,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err() return result, rows.Err()
} }
func (s *queueTransactionsStatements) SelectQueueTransactionCount( func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count) err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given // It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
queue, err := NewSQLiteRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,41 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,35 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayQueue interface {
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type RelayQueueJSON interface {
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}

View file

@ -8,10 +8,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -35,31 +35,31 @@ func mustCreateTransaction() gomatrixserverlib.Transaction {
return txn return txn
} }
type TransactionJSONDatabase struct { type RelayQueueJSONDatabase struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
Table tables.FederationTransactionJSON Table tables.RelayQueueJSON
} }
func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) { func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) {
t.Helper() t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter()) }, sqlutil.NewExclusiveWriter())
assert.NoError(t, err) assert.NoError(t, err)
var tab tables.FederationTransactionJSON var tab tables.RelayQueueJSON
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
tab, err = postgres.NewPostgresTransactionJSONTable(db) tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
assert.NoError(t, err) assert.NoError(t, err)
case test.DBTypeSQLite: case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteTransactionJSONTable(db) tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
assert.NoError(t, err) assert.NoError(t, err)
} }
assert.NoError(t, err) assert.NoError(t, err)
database = TransactionJSONDatabase{ database = RelayQueueJSONDatabase{
DB: db, DB: db,
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
Table: tab, Table: tab,
@ -70,7 +70,7 @@ func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database
func TestShoudInsertTransaction(t *testing.T) { func TestShoudInsertTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType) db, close := mustCreateQueueJSONTable(t, dbType)
defer close() defer close()
transaction := mustCreateTransaction() transaction := mustCreateTransaction()
@ -79,7 +79,7 @@ func TestShoudInsertTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error()) t.Fatalf("Invalid transaction: %s", err.Error())
} }
_, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx)) _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
@ -89,7 +89,7 @@ func TestShoudInsertTransaction(t *testing.T) {
func TestShouldRetrieveInsertedTransaction(t *testing.T) { func TestShouldRetrieveInsertedTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType) db, close := mustCreateQueueJSONTable(t, dbType)
defer close() defer close()
transaction := mustCreateTransaction() transaction := mustCreateTransaction()
@ -98,14 +98,14 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error()) t.Fatalf("Invalid transaction: %s", err.Error())
} }
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
var storedJSON map[int64][]byte var storedJSON map[int64][]byte
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err return err
}) })
if err != nil { if err != nil {
@ -124,7 +124,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
func TestShouldDeleteTransaction(t *testing.T) { func TestShouldDeleteTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType) db, close := mustCreateQueueJSONTable(t, dbType)
defer close() defer close()
transaction := mustCreateTransaction() transaction := mustCreateTransaction()
@ -133,14 +133,14 @@ func TestShouldDeleteTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error()) t.Fatalf("Invalid transaction: %s", err.Error())
} }
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
storedJSON := map[int64][]byte{} storedJSON := map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid}) err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
return err return err
}) })
if err != nil { if err != nil {
@ -149,7 +149,7 @@ func TestShouldDeleteTransaction(t *testing.T) {
storedJSON = map[int64][]byte{} storedJSON = map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err return err
}) })
if err != nil { if err != nil {

View file

@ -7,41 +7,41 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type QueueTransactionsDatabase struct { type RelayQueueDatabase struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
Table tables.FederationQueueTransactions Table tables.RelayQueue
} }
func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) { func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) {
t.Helper() t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter()) }, sqlutil.NewExclusiveWriter())
assert.NoError(t, err) assert.NoError(t, err)
var tab tables.FederationQueueTransactions var tab tables.RelayQueue
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
tab, err = postgres.NewPostgresQueueTransactionsTable(db) tab, err = postgres.NewPostgresRelayQueueTable(db)
assert.NoError(t, err) assert.NoError(t, err)
case test.DBTypeSQLite: case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db) tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
assert.NoError(t, err) assert.NoError(t, err)
} }
assert.NoError(t, err) assert.NoError(t, err)
database = QueueTransactionsDatabase{ database = RelayQueueDatabase{
DB: db, DB: db,
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),
Table: tab, Table: tab,
@ -52,13 +52,13 @@ func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (databas
func TestShoudInsertQueueTransaction(t *testing.T) { func TestShoudInsertQueueTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType) db, close := mustCreateQueueTable(t, dbType)
defer close() defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain") serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1) nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
@ -68,19 +68,19 @@ func TestShoudInsertQueueTransaction(t *testing.T) {
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType) db, close := mustCreateQueueTable(t, dbType)
defer close() defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain") serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1) nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10) retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error()) t.Fatalf("Failed retrieving transaction: %s", err.Error())
} }
@ -93,27 +93,27 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
func TestShouldDeleteQueueTransaction(t *testing.T) { func TestShouldDeleteQueueTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType) db, close := mustCreateQueueTable(t, dbType)
defer close() defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain") serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1) nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err return err
}) })
if err != nil { if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error()) t.Fatalf("Failed deleting transaction: %s", err.Error())
} }
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error()) t.Fatalf("Failed retrieving transaction count: %s", err.Error())
} }
@ -124,7 +124,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType) db, close := mustCreateQueueTable(t, dbType)
defer close() defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
@ -135,34 +135,34 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
nid2 := int64(2) nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid) err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2) err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
if err != nil { if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error()) t.Fatalf("Failed inserting transaction: %s", err.Error())
} }
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err return err
}) })
if err != nil { if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error()) t.Fatalf("Failed deleting transaction: %s", err.Error())
} }
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error()) t.Fatalf("Failed retrieving transaction count: %s", err.Error())
} }
assert.Equal(t, int64(1), count) assert.Equal(t, int64(1), count)
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2) count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
if err != nil { if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error()) t.Fatalf("Failed retrieving transaction count: %s", err.Error())
} }

View file

@ -62,6 +62,7 @@ type Dendrite struct {
RoomServer RoomServer `yaml:"room_server"` RoomServer RoomServer `yaml:"room_server"`
SyncAPI SyncAPI `yaml:"sync_api"` SyncAPI SyncAPI `yaml:"sync_api"`
UserAPI UserAPI `yaml:"user_api"` UserAPI UserAPI `yaml:"user_api"`
RelayAPI RelayAPI `yaml:"relay_api"`
MSCs MSCs `yaml:"mscs"` MSCs MSCs `yaml:"mscs"`

View file

@ -0,0 +1,38 @@
package config
type RelayAPI struct {
Matrix *Global `yaml:"-"`
InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"`
ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"`
// The database stores information used by the relay queue to
// forward transactions to remote servers.
Database DatabaseOptions `yaml:"database,omitempty"`
}
func (c *RelayAPI) Defaults(opts DefaultOpts) {
if !opts.Monolithic {
c.InternalAPI.Listen = "http://localhost:7775"
c.InternalAPI.Connect = "http://localhost:7775"
c.ExternalAPI.Listen = "http://[::]:8075"
c.Database.Defaults(10)
}
if opts.Generate {
if !opts.Monolithic {
c.Database.ConnectionString = "file:relayapi.db"
}
}
}
func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if isMonolith { // polylith required configs below
return
}
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString))
}
checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect))
}

View file

@ -23,6 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/internal/transactions"
keyAPI "github.com/matrix-org/dendrite/keyserver/api" keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -44,6 +46,7 @@ type Monolith struct {
RoomserverAPI roomserverAPI.RoomserverInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI UserAPI userapi.UserInternalAPI
KeyAPI keyAPI.KeyInternalAPI KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional // Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
@ -71,4 +74,9 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
syncapi.AddPublicRoutes( syncapi.AddPublicRoutes(
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
) )
// TODO : relayapi.AddPublicRoutes
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.RelayAPI, m.FederationAPI, m.KeyAPI)
}
} }