Refactor all relay specific stuff into it's own component

This commit is contained in:
Devon Hudson 2022-12-14 18:41:27 -07:00
parent f300a4d0e9
commit ad53326ce8
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
39 changed files with 1433 additions and 850 deletions

View file

@ -458,7 +458,7 @@ func (m *DendriteMonolith) Start() {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
if !relayServerSyncRunning.Load() {
go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
// go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
}
case pineconeEvents.PeerRemoved:
if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 {
@ -486,46 +486,6 @@ func (m *DendriteMonolith) Start() {
}(pineconeEventChannel)
}
func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false)
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.relayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
}
m.queryRelayServers(relayServersToQuery)
t.Reset(relayServerRetryInterval)
select {
case <-stop:
if !t.Stop() {
<-t.C
}
return
case <-t.C:
}
}
}
func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
for _, server := range relayServers {
request := api.PerformRelayServerSyncRequest{RelayServer: server}
response := api.PerformRelayServerSyncResponse{}
err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response)
if err == nil {
m.relayServersQueried[server] = true
}
}
}
func (m *DendriteMonolith) Stop() {
m.processContext.ShutdownDendrite()
_ = m.listener.Close()

View file

@ -43,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@ -145,6 +146,7 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false
@ -235,6 +237,20 @@ func main() {
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
monolith := setup.Monolith{
Config: base.Cfg,
Client: conn.CreateClient(base, pQUIC),
@ -246,6 +262,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
RelayAPI: &relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
}
@ -319,32 +336,12 @@ func main() {
relayServerSyncRunning := atomic.NewBool(false)
stopRelayServerSync := make(chan bool)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.FederationAPI.Matrix.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
m := RelayServerRetriever{
Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
FederationAPI: fsAPI,
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: relayapi.NewRelayAPI(
federation,
rsAPI,
keyRing,
producer,
cfg.Global.Presence.EnableInbound,
cfg.Global.ServerName,
),
RelayAPI: monolith.RelayAPI,
}
m.InitializeRelayServers(eLog)
@ -387,7 +384,7 @@ type RelayServerRetriever struct {
ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI
RelayServersQueried map[gomatrixserverlib.ServerName]bool
RelayAPI relayapi.RelayAPI
RelayAPI relayServerAPI.RelayInternalAPI
}
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
@ -440,7 +437,12 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli
if err != nil {
return
}
err = m.RelayAPI.PerformRelayServerSync(*userID, server)
request := relayServerAPI.PerformRelayServerSyncRequest{
UserID: *userID,
RelayServer: server,
}
response := relayServerAPI.PerformRelayServerSyncResponse{}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response)
if err == nil {
m.RelayServersQueried[server] = true
// TODO : What happens if your relay receives new messages after this point?

View file

@ -18,7 +18,6 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase
ClientFederationAPI
RoomserverFederationAPI
RelayServerAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -45,22 +44,6 @@ type FederationInternalAPI interface {
) error
}
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type ClientFederationAPI interface {
// Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice

View file

@ -14,7 +14,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
)
@ -847,64 +846,3 @@ func (r *FederationInternalAPI) QueryRelayServers(
response.RelayServers = relayServers
return nil
}
// PerformStoreAsync implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.FederationInternalAPI
func (r *FederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}

View file

@ -26,9 +26,6 @@ const (
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
@ -525,25 +522,3 @@ func (h *httpFederationInternalAPI) QueryRelayServers(
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.federationAPIURL+FederationAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -43,16 +43,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
)
internalAPIMux.Handle(
FederationAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
FederationAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
internalAPIMux.Handle(
FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),

View file

@ -1,199 +0,0 @@
package routing_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func createDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

View file

@ -133,37 +133,6 @@ func Setup(
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI(
"federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {

View file

@ -51,12 +51,6 @@ type Database interface {
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
// these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error

View file

@ -58,18 +58,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
queueTransactions, err := NewPostgresQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresQueueJSONTable(d.db)
if err != nil {
return nil, err
}
transactionJSON, err := NewPostgresTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil {
return nil, err
@ -111,24 +103,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,
NotaryServerKeysMetadata: notaryMetadata,
ServerSigningKeys: serverSigningKeys,
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,
NotaryServerKeysMetadata: notaryMetadata,
ServerSigningKeys: serverSigningKeys,
}
return &d, nil
}

View file

@ -17,7 +17,6 @@ package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
@ -29,24 +28,22 @@ import (
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs
FederationQueueEDUs tables.FederationQueueEDUs
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys
FederationQueueTransactions tables.FederationQueueTransactions
FederationTransactionJSON tables.FederationTransactionJSON
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs
FederationQueueEDUs tables.FederationQueueEDUs
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys
}
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
@ -61,6 +58,10 @@ func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) GetNID() int64 {
return r.nid
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
@ -308,108 +309,3 @@ func (d *Database) GetNotaryKeys(
})
return sks, err
}
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertTransactionJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.InsertQueueTransaction(
ctx, txn, transactionID, destination.Domain(), receipt.nid)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.nid
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *Receipt, error) {
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.FederationTransactionJSON.SelectTransactionJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err)
}
return count, nil
}

View file

@ -55,14 +55,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
transactionJSON, err := NewSQLiteTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil {
return nil, err
@ -104,24 +96,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,
NotaryServerKeysMetadata: notaryKeysMetadata,
ServerSigningKeys: serverSigningKeys,
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,
NotaryServerKeysMetadata: notaryKeysMetadata,
ServerSigningKeys: serverSigningKeys,
}
return &d, nil
}

81
relayapi/api/api.go Normal file
View file

@ -0,0 +1,81 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
// RelayInternalAPI is used to query information from the relay server.
type RelayInternalAPI interface {
RelayServerAPI
PerformRelayServerSync(
ctx context.Context,
request *PerformRelayServerSyncRequest,
response *PerformRelayServerSyncResponse,
) error
}
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type PerformRelayServerSyncRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
RelayServer gomatrixserverlib.ServerName `json:"relay_name"`
}
type PerformRelayServerSyncResponse struct {
}
type QueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName
}
type QueryRelayServersResponse struct {
RelayServers []gomatrixserverlib.ServerName
}
type PerformStoreAsyncRequest struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type PerformStoreAsyncResponse struct {
}
type QueryAsyncTransactionsRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type QueryAsyncTransactionsResponse struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
RemainingCount uint32 `json:"remaining"`
}

52
relayapi/internal/api.go Normal file
View file

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayInternalAPI struct {
db storage.Database
fedClient *gomatrixserverlib.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
}
func NewRelayInternalAPI(
db storage.Database,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) RelayInternalAPI {
return RelayInternalAPI{
db: db,
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}

View file

@ -0,0 +1,131 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// PerformRelayServerSync implements api.FederationInternalAPI
func (r *RelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
}
return nil
}
// PerformStoreAsync implements api.RelayInternalAPI
func (r *RelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.RelayInternalAPI
func (r *RelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}
func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
}

View file

@ -0,0 +1,70 @@
package inthttp
import (
"context"
"errors"
"net/http"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// HTTP paths for the internal HTTP API
const (
RelayAPIPerformRelayServerSyncPath = "/relayapi/performRelayServerSync"
RelayAPIPerformStoreAsyncPath = "/relayapi/performStoreAsync"
RelayAPIQueryAsyncTransactionsPath = "/relayapi/queryAsyncTransactions"
)
// NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API.
// If httpClient is nil an error is returned
func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) {
if httpClient == nil {
return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is <nil>")
}
return &httpRelayInternalAPI{
relayAPIURL: relayapiURL,
httpClient: httpClient,
cache: cache,
}, nil
}
type httpRelayInternalAPI struct {
relayAPIURL string
httpClient *http.Client
cache caching.ServerKeyCache
}
func (h *httpRelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformRelayServerSync", h.relayAPIURL+RelayAPIPerformRelayServerSyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.relayAPIURL+RelayAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.relayAPIURL+RelayAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -0,0 +1,27 @@
package inthttp
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// AddRoutes adds the RelayInternalAPI handlers to the http.ServeMux.
// nolint:gocyclo
func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
RelayAPIPerformRelayServerSyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", intAPI.PerformRelayServerSync),
)
internalAPIMux.Handle(
RelayAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
RelayAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("RelayAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
}

View file

@ -15,79 +15,77 @@
package relayapi
import (
"context"
"github.com/gorilla/mux"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/relayapi/api"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/inthttp"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
type RelayAPI struct {
fedClient *gomatrixserverlib.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.RelayInternalAPI) {
inthttp.AddRoutes(intAPI, router)
}
func NewRelayAPI(
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.UserInternalAPI,
fedClient *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI rsAPI.FederationRoomserverAPI,
relayAPI relayAPI.RelayInternalAPI,
fedAPI federationAPI.FederationInternalAPI,
keyAPI keyserverAPI.FederationKeyAPI,
) {
fedCfg := &base.Cfg.FederationAPI
relay, ok := relayAPI.(*internal.RelayInternalAPI)
if !ok {
panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
"RelayInternalAPI. This is a programming error.")
}
routing.Setup(
base.PublicFederationAPIMux,
fedCfg,
relay,
keyRing,
)
}
func NewRelayInternalAPI(
base *base.BaseDendrite,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) RelayAPI {
return RelayAPI{
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}
) internal.RelayInternalAPI {
cfg := &base.Cfg.RelayAPI
// PerformRelayServerSync implements api.FederationInternalAPI
func (r *RelayAPI) PerformRelayServerSync(userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
logrus.WithError(err).Panic("failed to connect to relay db")
}
return nil
}
func (r *RelayAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
return internal.NewRelayInternalAPI(
relayDB,
fedClient,
rsAPI,
keyRing,
producer,
base.Cfg.Global.Presence.EnableInbound,
base.Cfg.Global.ServerName,
)
}

View file

@ -3,7 +3,7 @@ package routing
import (
"net/http"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -19,12 +19,12 @@ type AsyncEventsResponse struct {
func GetAsyncEvents(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
relayAPI api.RelayInternalAPI,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
logrus.Infof("Handling async_events for %v", userID)
var response api.QueryAsyncTransactionsResponse
err := fedAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,

View file

@ -5,21 +5,21 @@ import (
"net/http"
"testing"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -33,11 +33,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
t.Fatalf("Failed to store transaction: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -46,11 +46,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
}
func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -74,11 +74,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -87,11 +87,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
}
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -131,18 +131,18 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(1), jsonResponse.Remaining)
assert.Equal(t, transaction, jsonResponse.Transaction)
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse)

View file

@ -5,7 +5,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -15,7 +15,7 @@ import (
func ForwardAsync(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
relayAPI api.RelayInternalAPI,
txnID gomatrixserverlib.TransactionID,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
@ -54,7 +54,7 @@ func ForwardAsync(
UserID: userID,
}
res := api.PerformStoreAsyncResponse{}
err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
err := relayAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,

View file

@ -0,0 +1,111 @@
package routing_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(
httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

123
relayapi/routing/routing.go Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"net/http"
"time"
"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
//
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
func Setup(
fedMux *mux.Router,
cfg *config.FederationAPI,
relayAPI *relayInternal.RelayInternalAPI,
keys gomatrixserverlib.JSONVerifier,
) {
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeRelayAPI(
"relay_forward_async", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeRelayAPI(
"relay_async_events", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, relayAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
}
// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
func MakeRelayAPI(
metricsName string, serverName gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
}
jsonRes := f(req, fedReq, vars)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return httputil.MakeExternalAPI(metricsName, h)
}

View file

@ -0,0 +1,95 @@
package storage
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func NewFakeRelayDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}

View file

@ -0,0 +1,30 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
}

View file

@ -23,60 +23,60 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)" +
" RETURNING json_nid"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid = ANY($1)"
type transactionJSONStatements struct {
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt
}
func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = s.db.Exec(transactionJSONSchema)
_, err = s.db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil {
if s.insertJSONStmt, err = s.db.Prepare(insertQueueJSONSQL); err != nil {
return
}
if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil {
if s.deleteJSONStmt, err = s.db.Prepare(deleteQueueJSONSQL); err != nil {
return
}
if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil {
if s.selectJSONStmt, err = s.db.Prepare(selectQueueJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -87,7 +87,7 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return lastid, nil
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
@ -95,7 +95,7 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
blobs := map[int64][]byte{}

View file

@ -24,79 +24,79 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_transaction_json table.
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
deleteQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
type relayQueueStatements struct {
db *sql.DB
insertQueueEntryStmt *sql.Stmt
deleteQueueEntriesStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
}
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = s.db.Exec(queueTransactionsSchema)
_, err = s.db.Exec(relayQueueSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil {
if s.insertQueueEntryStmt, err = s.db.Prepare(insertQueueEntrySQL); err != nil {
return
}
if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil {
if s.deleteQueueEntriesStmt, err = s.db.Prepare(deleteQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil {
if s.selectQueueEntriesStmt, err = s.db.Prepare(selectQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil {
if s.selectQueueEntryCountStmt, err = s.db.Prepare(selectQueueEntryCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
@ -106,22 +106,22 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
@ -139,11 +139,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,59 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the relayapi
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
queue, err := NewPostgresRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
RelayQueue tables.RelayQueue
RelayQueueJSON tables.RelayQueueJSON
}
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*shared.Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
receipt := shared.NewReceipt(nid)
return &receipt, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *shared.Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.InsertQueueEntry(
ctx, txn, transactionID, destination.Domain(), receipt.GetNID())
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueEntry: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*shared.Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.GetNID()
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.deleteQueueEntries: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *shared.Receipt, error) {
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := shared.NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
}
return count, nil
}

View file

@ -24,53 +24,53 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid IN ($1)"
type transactionJSONStatements struct {
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = db.Exec(transactionJSONSchema)
_, err = db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil {
if s.insertJSONStmt, err = db.Prepare(insertQueueJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -85,13 +85,13 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err)
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(nids))
@ -104,13 +104,13 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
@ -122,14 +122,14 @@ func (s *transactionJSONStatements) SelectTransactionJSON(
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
}
blobs[nid] = blob
}

View file

@ -25,79 +25,79 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_transactions_json table.
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
// deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic
type relayQueueStatements struct {
db *sql.DB
insertQueueEntryStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
// deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = db.Exec(queueTransactionsSchema)
_, err = db.Exec(relayQueueSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil {
if s.insertQueueEntryStmt, err = db.Prepare(insertQueueEntrySQL); err != nil {
return
}
//if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil {
//if s.deleteQueueEntriesStmt, err = db.Prepare(deleteQueueEntriesSQL); err != nil {
// return
//}
if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil {
if s.selectQueueEntriesStmt, err = db.Prepare(selectQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil {
if s.selectQueueEntryCountStmt, err = db.Prepare(selectQueueEntryCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
@ -107,15 +107,15 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err)
return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
@ -129,12 +129,12 @@ func (s *queueTransactionsStatements) DeleteQueueTransactions(
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
@ -152,11 +152,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
queue, err := NewSQLiteRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,41 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,35 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayQueue interface {
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type RelayQueueJSON interface {
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}

View file

@ -8,10 +8,10 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
@ -35,31 +35,31 @@ func mustCreateTransaction() gomatrixserverlib.Transaction {
return txn
}
type TransactionJSONDatabase struct {
type RelayQueueJSONDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationTransactionJSON
Table tables.RelayQueueJSON
}
func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) {
func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationTransactionJSON
var tab tables.RelayQueueJSON
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTransactionJSONTable(db)
tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteTransactionJSONTable(db)
tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = TransactionJSONDatabase{
database = RelayQueueJSONDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
@ -70,7 +70,7 @@ func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database
func TestShoudInsertTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -79,7 +79,7 @@ func TestShoudInsertTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
_, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx))
_, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
@ -89,7 +89,7 @@ func TestShoudInsertTransaction(t *testing.T) {
func TestShouldRetrieveInsertedTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -98,14 +98,14 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
var storedJSON map[int64][]byte
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
@ -124,7 +124,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
func TestShouldDeleteTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -133,14 +133,14 @@ func TestShouldDeleteTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
storedJSON := map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid})
err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
@ -149,7 +149,7 @@ func TestShouldDeleteTransaction(t *testing.T) {
storedJSON = map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {

View file

@ -7,41 +7,41 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type QueueTransactionsDatabase struct {
type RelayQueueDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationQueueTransactions
Table tables.RelayQueue
}
func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) {
func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationQueueTransactions
var tab tables.RelayQueue
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresQueueTransactionsTable(db)
tab, err = postgres.NewPostgresRelayQueueTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db)
tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = QueueTransactionsDatabase{
database = RelayQueueDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
@ -52,13 +52,13 @@ func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (databas
func TestShoudInsertQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
@ -68,19 +68,19 @@ func TestShoudInsertQueueTransaction(t *testing.T) {
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10)
retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
@ -93,27 +93,27 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
func TestShouldDeleteQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
@ -124,7 +124,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
@ -135,34 +135,34 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, int64(1), count)
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}

View file

@ -62,6 +62,7 @@ type Dendrite struct {
RoomServer RoomServer `yaml:"room_server"`
SyncAPI SyncAPI `yaml:"sync_api"`
UserAPI UserAPI `yaml:"user_api"`
RelayAPI RelayAPI `yaml:"relay_api"`
MSCs MSCs `yaml:"mscs"`

View file

@ -0,0 +1,38 @@
package config
type RelayAPI struct {
Matrix *Global `yaml:"-"`
InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"`
ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"`
// The database stores information used by the relay queue to
// forward transactions to remote servers.
Database DatabaseOptions `yaml:"database,omitempty"`
}
func (c *RelayAPI) Defaults(opts DefaultOpts) {
if !opts.Monolithic {
c.InternalAPI.Listen = "http://localhost:7775"
c.InternalAPI.Connect = "http://localhost:7775"
c.ExternalAPI.Listen = "http://[::]:8075"
c.Database.Defaults(10)
}
if opts.Generate {
if !opts.Monolithic {
c.Database.ConnectionString = "file:relayapi.db"
}
}
}
func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if isMonolith { // polylith required configs below
return
}
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString))
}
checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect))
}

View file

@ -23,6 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal/transactions"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
@ -44,6 +46,7 @@ type Monolith struct {
RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI
KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
@ -71,4 +74,9 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
syncapi.AddPublicRoutes(
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
)
// TODO : relayapi.AddPublicRoutes
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.RelayAPI, m.FederationAPI, m.KeyAPI)
}
}