Refactor all relay specific stuff into it's own component

This commit is contained in:
Devon Hudson 2022-12-14 18:41:27 -07:00
parent f300a4d0e9
commit ad53326ce8
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
39 changed files with 1433 additions and 850 deletions

View file

@ -458,7 +458,7 @@ func (m *DendriteMonolith) Start() {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
if !relayServerSyncRunning.Load() {
go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
// go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
}
case pineconeEvents.PeerRemoved:
if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 {
@ -486,46 +486,6 @@ func (m *DendriteMonolith) Start() {
}(pineconeEventChannel)
}
func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false)
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.relayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
}
m.queryRelayServers(relayServersToQuery)
t.Reset(relayServerRetryInterval)
select {
case <-stop:
if !t.Stop() {
<-t.C
}
return
case <-t.C:
}
}
}
func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
for _, server := range relayServers {
request := api.PerformRelayServerSyncRequest{RelayServer: server}
response := api.PerformRelayServerSyncResponse{}
err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response)
if err == nil {
m.relayServersQueried[server] = true
}
}
}
func (m *DendriteMonolith) Stop() {
m.processContext.ShutdownDendrite()
_ = m.listener.Close()

View file

@ -43,6 +43,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@ -145,6 +146,7 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false
@ -235,6 +237,20 @@ func main() {
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
monolith := setup.Monolith{
Config: base.Cfg,
Client: conn.CreateClient(base, pQUIC),
@ -246,6 +262,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
RelayAPI: &relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
}
@ -319,32 +336,12 @@ func main() {
relayServerSyncRunning := atomic.NewBool(false)
stopRelayServerSync := make(chan bool)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.FederationAPI.Matrix.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
m := RelayServerRetriever{
Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
FederationAPI: fsAPI,
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: relayapi.NewRelayAPI(
federation,
rsAPI,
keyRing,
producer,
cfg.Global.Presence.EnableInbound,
cfg.Global.ServerName,
),
RelayAPI: monolith.RelayAPI,
}
m.InitializeRelayServers(eLog)
@ -387,7 +384,7 @@ type RelayServerRetriever struct {
ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI
RelayServersQueried map[gomatrixserverlib.ServerName]bool
RelayAPI relayapi.RelayAPI
RelayAPI relayServerAPI.RelayInternalAPI
}
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
@ -440,7 +437,12 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli
if err != nil {
return
}
err = m.RelayAPI.PerformRelayServerSync(*userID, server)
request := relayServerAPI.PerformRelayServerSyncRequest{
UserID: *userID,
RelayServer: server,
}
response := relayServerAPI.PerformRelayServerSyncResponse{}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response)
if err == nil {
m.RelayServersQueried[server] = true
// TODO : What happens if your relay receives new messages after this point?

View file

@ -18,7 +18,6 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase
ClientFederationAPI
RoomserverFederationAPI
RelayServerAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -45,22 +44,6 @@ type FederationInternalAPI interface {
) error
}
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type ClientFederationAPI interface {
// Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice

View file

@ -14,7 +14,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
)
@ -847,64 +846,3 @@ func (r *FederationInternalAPI) QueryRelayServers(
response.RelayServers = relayServers
return nil
}
// PerformStoreAsync implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.FederationInternalAPI
func (r *FederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}

View file

@ -26,9 +26,6 @@ const (
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
@ -525,25 +522,3 @@ func (h *httpFederationInternalAPI) QueryRelayServers(
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.federationAPIURL+FederationAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -43,16 +43,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
)
internalAPIMux.Handle(
FederationAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
FederationAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
internalAPIMux.Handle(
FederationAPIPerformWakeupServers,
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),

View file

@ -1,199 +0,0 @@
package routing_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func createDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := createDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
)
response := routing.ForwardAsync(
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

View file

@ -133,37 +133,6 @@ func Setup(
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI(
"federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {

View file

@ -51,12 +51,6 @@ type Database interface {
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
// these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error

View file

@ -58,18 +58,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
queueTransactions, err := NewPostgresQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresQueueJSONTable(d.db)
if err != nil {
return nil, err
}
transactionJSON, err := NewPostgresTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil {
return nil, err
@ -119,8 +111,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,

View file

@ -17,7 +17,6 @@ package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
@ -45,8 +44,6 @@ type Database struct {
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys
FederationQueueTransactions tables.FederationQueueTransactions
FederationTransactionJSON tables.FederationTransactionJSON
}
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
@ -61,6 +58,10 @@ func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) GetNID() int64 {
return r.nid
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
@ -308,108 +309,3 @@ func (d *Database) GetNotaryKeys(
})
return sks, err
}
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertTransactionJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.InsertQueueTransaction(
ctx, txn, transactionID, destination.Domain(), receipt.nid)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.nid
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueTransaction: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *Receipt, error) {
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.FederationTransactionJSON.SelectTransactionJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err)
}
return count, nil
}

View file

@ -55,14 +55,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db)
if err != nil {
return nil, err
}
transactionJSON, err := NewSQLiteTransactionJSONTable(d.db)
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil {
return nil, err
@ -112,8 +104,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,

81
relayapi/api/api.go Normal file
View file

@ -0,0 +1,81 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
// RelayInternalAPI is used to query information from the relay server.
type RelayInternalAPI interface {
RelayServerAPI
PerformRelayServerSync(
ctx context.Context,
request *PerformRelayServerSyncRequest,
response *PerformRelayServerSyncResponse,
) error
}
type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync(
ctx context.Context,
request *PerformStoreAsyncRequest,
response *PerformStoreAsyncResponse,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryAsyncTransactions(
ctx context.Context,
request *QueryAsyncTransactionsRequest,
response *QueryAsyncTransactionsResponse,
) error
}
type PerformRelayServerSyncRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
RelayServer gomatrixserverlib.ServerName `json:"relay_name"`
}
type PerformRelayServerSyncResponse struct {
}
type QueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName
}
type QueryRelayServersResponse struct {
RelayServers []gomatrixserverlib.ServerName
}
type PerformStoreAsyncRequest struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type PerformStoreAsyncResponse struct {
}
type QueryAsyncTransactionsRequest struct {
UserID gomatrixserverlib.UserID `json:"user_id"`
}
type QueryAsyncTransactionsResponse struct {
Txn gomatrixserverlib.Transaction `json:"transaction"`
RemainingCount uint32 `json:"remaining"`
}

52
relayapi/internal/api.go Normal file
View file

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayInternalAPI struct {
db storage.Database
fedClient *gomatrixserverlib.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
}
func NewRelayInternalAPI(
db storage.Database,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) RelayInternalAPI {
return RelayInternalAPI{
db: db,
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}

View file

@ -0,0 +1,131 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// PerformRelayServerSync implements api.FederationInternalAPI
func (r *RelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
}
return nil
}
// PerformStoreAsync implements api.RelayInternalAPI
func (r *RelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
logrus.Warnf("Storing transaction for %v", request.UserID)
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
if err != nil {
return err
}
err = r.db.AssociateAsyncTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
request.UserID: {},
},
request.Txn.TransactionID,
receipt)
return err
}
// QueryAsyncTransactions implements api.RelayInternalAPI
func (r *RelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
logrus.Warnf("Obtaining transaction for %v", request.UserID)
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
if err != nil {
return err
}
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil {
return err
}
// TODO : Clean async transactions json
}
// TODO : These db calls should happen at the same time right?
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
if err != nil {
return err
}
response.RemainingCount = uint32(count)
if transaction != nil {
response.Txn = *transaction
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
}
return nil
}
func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
}

View file

@ -0,0 +1,70 @@
package inthttp
import (
"context"
"errors"
"net/http"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// HTTP paths for the internal HTTP API
const (
RelayAPIPerformRelayServerSyncPath = "/relayapi/performRelayServerSync"
RelayAPIPerformStoreAsyncPath = "/relayapi/performStoreAsync"
RelayAPIQueryAsyncTransactionsPath = "/relayapi/queryAsyncTransactions"
)
// NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API.
// If httpClient is nil an error is returned
func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) {
if httpClient == nil {
return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is <nil>")
}
return &httpRelayInternalAPI{
relayAPIURL: relayapiURL,
httpClient: httpClient,
cache: cache,
}, nil
}
type httpRelayInternalAPI struct {
relayAPIURL string
httpClient *http.Client
cache caching.ServerKeyCache
}
func (h *httpRelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
request *api.PerformRelayServerSyncRequest,
response *api.PerformRelayServerSyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformRelayServerSync", h.relayAPIURL+RelayAPIPerformRelayServerSyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) PerformStoreAsync(
ctx context.Context,
request *api.PerformStoreAsyncRequest,
response *api.PerformStoreAsyncResponse,
) error {
return httputil.CallInternalRPCAPI(
"PerformStoreAsync", h.relayAPIURL+RelayAPIPerformStoreAsyncPath,
h.httpClient, ctx, request, response,
)
}
func (h *httpRelayInternalAPI) QueryAsyncTransactions(
ctx context.Context,
request *api.QueryAsyncTransactionsRequest,
response *api.QueryAsyncTransactionsResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAsyncTransactions", h.relayAPIURL+RelayAPIQueryAsyncTransactionsPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -0,0 +1,27 @@
package inthttp
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/relayapi/api"
)
// AddRoutes adds the RelayInternalAPI handlers to the http.ServeMux.
// nolint:gocyclo
func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
RelayAPIPerformRelayServerSyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", intAPI.PerformRelayServerSync),
)
internalAPIMux.Handle(
RelayAPIPerformStoreAsyncPath,
httputil.MakeInternalRPCAPI("RelayAPIPerformStoreAsync", intAPI.PerformStoreAsync),
)
internalAPIMux.Handle(
RelayAPIQueryAsyncTransactionsPath,
httputil.MakeInternalRPCAPI("RelayAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
)
}

View file

@ -15,79 +15,77 @@
package relayapi
import (
"context"
"github.com/gorilla/mux"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/relayapi/api"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/inthttp"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
type RelayAPI struct {
fedClient *gomatrixserverlib.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.RelayInternalAPI) {
inthttp.AddRoutes(intAPI, router)
}
func NewRelayAPI(
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.UserInternalAPI,
fedClient *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI rsAPI.FederationRoomserverAPI,
relayAPI relayAPI.RelayInternalAPI,
fedAPI federationAPI.FederationInternalAPI,
keyAPI keyserverAPI.FederationKeyAPI,
) {
fedCfg := &base.Cfg.FederationAPI
relay, ok := relayAPI.(*internal.RelayInternalAPI)
if !ok {
panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
"RelayInternalAPI. This is a programming error.")
}
routing.Setup(
base.PublicFederationAPIMux,
fedCfg,
relay,
keyRing,
)
}
func NewRelayInternalAPI(
base *base.BaseDendrite,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) RelayAPI {
return RelayAPI{
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}
) internal.RelayInternalAPI {
cfg := &base.Cfg.RelayAPI
// PerformRelayServerSync implements api.FederationInternalAPI
func (r *RelayAPI) PerformRelayServerSync(userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 {
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Transaction)
logrus.WithError(err).Panic("failed to connect to relay db")
}
return nil
}
func (r *RelayAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
return internal.NewRelayInternalAPI(
relayDB,
fedClient,
rsAPI,
keyRing,
producer,
base.Cfg.Global.Presence.EnableInbound,
base.Cfg.Global.ServerName,
)
}

View file

@ -3,7 +3,7 @@ package routing
import (
"net/http"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -19,12 +19,12 @@ type AsyncEventsResponse struct {
func GetAsyncEvents(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
relayAPI api.RelayInternalAPI,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
logrus.Infof("Handling async_events for %v", userID)
var response api.QueryAsyncTransactionsResponse
err := fedAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,

View file

@ -5,21 +5,21 @@ import (
"net/http"
"testing"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -33,11 +33,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
t.Fatalf("Failed to store transaction: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -46,11 +46,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
}
func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -74,11 +74,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
@ -87,11 +87,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
}
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
testDB := createDatabase()
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
FederationQueueTransactions: testDB,
FederationTransactionJSON: testDB,
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
@ -131,18 +131,18 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
}
fedAPI := internal.NewFederationInternalAPI(
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
assert.Equal(t, uint32(1), jsonResponse.Remaining)
assert.Equal(t, transaction, jsonResponse.Transaction)
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.AsyncEventsResponse)

View file

@ -5,7 +5,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -15,7 +15,7 @@ import (
func ForwardAsync(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
fedAPI api.FederationInternalAPI,
relayAPI api.RelayInternalAPI,
txnID gomatrixserverlib.TransactionID,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
@ -54,7 +54,7 @@ func ForwardAsync(
UserID: userID,
}
res := api.PerformStoreAsyncResponse{}
err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
err := relayAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,

View file

@ -0,0 +1,111 @@
package routing_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
txn := createTransaction()
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
request.SetContent(txn)
return txn, request
}
func TestEmptyForwardReturnsOk(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
_, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
expected := 200
if response.Code != expected {
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
}
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := storage.NewFakeRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
inputTransaction, request := createFederationRequest(*userID)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.ForwardAsync(
httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID)
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
expected := 200
if response.Code != expected {
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
}
if transactionCount != 1 {
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
}
if transaction.TransactionID != inputTransaction.TransactionID {
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
inputTransaction.TransactionID, transaction.TransactionID)
}
}

123
relayapi/routing/routing.go Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"net/http"
"time"
"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
//
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
func Setup(
fedMux *mux.Router,
cfg *config.FederationAPI,
relayAPI *relayInternal.RelayInternalAPI,
keys gomatrixserverlib.JSONVerifier,
) {
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeRelayAPI(
"relay_forward_async", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return ForwardAsync(
httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/async_events/{userID}", MakeRelayAPI(
"relay_async_events", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetAsyncEvents(httpReq, request, relayAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
}
// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
func MakeRelayAPI(
metricsName string, serverName gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
}
jsonRes := f(req, fedReq, vars)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return httputil.MakeExternalAPI(metricsName, h)
}

View file

@ -0,0 +1,95 @@
package storage
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
type testDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func NewFakeRelayDatabase() *testDatabase {
return &testDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}

View file

@ -0,0 +1,30 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
}

View file

@ -23,60 +23,60 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)" +
" RETURNING json_nid"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid = ANY($1)"
type transactionJSONStatements struct {
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt
}
func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = s.db.Exec(transactionJSONSchema)
_, err = s.db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil {
if s.insertJSONStmt, err = s.db.Prepare(insertQueueJSONSQL); err != nil {
return
}
if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil {
if s.deleteJSONStmt, err = s.db.Prepare(deleteQueueJSONSQL); err != nil {
return
}
if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil {
if s.selectJSONStmt, err = s.db.Prepare(selectQueueJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -87,7 +87,7 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return lastid, nil
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
@ -95,7 +95,7 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
blobs := map[int64][]byte{}

View file

@ -24,79 +24,79 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_transaction_json table.
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
type relayQueueStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
deleteQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
insertQueueEntryStmt *sql.Stmt
deleteQueueEntriesStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
}
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = s.db.Exec(queueTransactionsSchema)
_, err = s.db.Exec(relayQueueSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil {
if s.insertQueueEntryStmt, err = s.db.Prepare(insertQueueEntrySQL); err != nil {
return
}
if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil {
if s.deleteQueueEntriesStmt, err = s.db.Prepare(deleteQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil {
if s.selectQueueEntriesStmt, err = s.db.Prepare(selectQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil {
if s.selectQueueEntryCountStmt, err = s.db.Prepare(selectQueueEntryCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
@ -106,22 +106,22 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
@ -139,11 +139,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,59 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the relayapi
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
queue, err := NewPostgresRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
RelayQueue tables.RelayQueue
RelayQueueJSON tables.RelayQueueJSON
}
func (d *Database) StoreAsyncTransaction(
ctx context.Context, txn gomatrixserverlib.Transaction,
) (*shared.Receipt, error) {
var err error
json, err := json.Marshal(txn)
if err != nil {
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(json))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
receipt := shared.NewReceipt(nid)
return &receipt, nil
}
func (d *Database) AssociateAsyncTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *shared.Receipt,
) error {
for destination := range destinations {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.InsertQueueEntry(
ctx, txn, transactionID, destination.Domain(), receipt.GetNID())
return err
})
if err != nil {
return fmt.Errorf("d.insertQueueEntry: %w", err)
}
}
return nil
}
func (d *Database) CleanAsyncTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*shared.Receipt,
) error {
println(len(receipts))
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.GetNID()
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
return err
})
if err != nil {
return fmt.Errorf("d.deleteQueueEntries: %w", err)
}
return nil
}
func (d *Database) GetAsyncTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *shared.Receipt, error) {
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
err = json.Unmarshal(txns[nids[0]], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
}
receipt := shared.NewReceipt(nids[0])
return transaction, &receipt, nil
}
func (d *Database) GetAsyncTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
}
return count, nil
}

View file

@ -24,53 +24,53 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid IN ($1)"
type transactionJSONStatements struct {
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = db.Exec(transactionJSONSchema)
_, err = db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil {
if s.insertJSONStmt, err = db.Prepare(insertQueueJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
@ -85,13 +85,13 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
return
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err)
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(nids))
@ -104,13 +104,13 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
@ -122,14 +122,14 @@ func (s *transactionJSONStatements) SelectTransactionJSON(
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
}
blobs[nid] = blob
}

View file

@ -25,79 +25,79 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_transactions_json table.
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
type relayQueueStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
// deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic
insertQueueEntryStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
// deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = db.Exec(queueTransactionsSchema)
_, err = db.Exec(relayQueueSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil {
if s.insertQueueEntryStmt, err = db.Prepare(insertQueueEntrySQL); err != nil {
return
}
//if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil {
//if s.deleteQueueEntriesStmt, err = db.Prepare(deleteQueueEntriesSQL); err != nil {
// return
//}
if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil {
if s.selectQueueEntriesStmt, err = db.Prepare(selectQueueEntriesSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil {
if s.selectQueueEntryCountStmt, err = db.Prepare(selectQueueEntryCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
@ -107,15 +107,15 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err)
return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
@ -129,12 +129,12 @@ func (s *queueTransactionsStatements) DeleteQueueTransactions(
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
@ -152,11 +152,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given

View file

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
queue, err := NewSQLiteRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View file

@ -0,0 +1,41 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,35 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayQueue interface {
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type RelayQueueJSON interface {
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}

View file

@ -8,10 +8,10 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
@ -35,31 +35,31 @@ func mustCreateTransaction() gomatrixserverlib.Transaction {
return txn
}
type TransactionJSONDatabase struct {
type RelayQueueJSONDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationTransactionJSON
Table tables.RelayQueueJSON
}
func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) {
func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationTransactionJSON
var tab tables.RelayQueueJSON
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTransactionJSONTable(db)
tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteTransactionJSONTable(db)
tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = TransactionJSONDatabase{
database = RelayQueueJSONDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
@ -70,7 +70,7 @@ func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database
func TestShoudInsertTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -79,7 +79,7 @@ func TestShoudInsertTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
_, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx))
_, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
@ -89,7 +89,7 @@ func TestShoudInsertTransaction(t *testing.T) {
func TestShouldRetrieveInsertedTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -98,14 +98,14 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
var storedJSON map[int64][]byte
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
@ -124,7 +124,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
func TestShouldDeleteTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
@ -133,14 +133,14 @@ func TestShouldDeleteTransaction(t *testing.T) {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
storedJSON := map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid})
err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
@ -149,7 +149,7 @@ func TestShouldDeleteTransaction(t *testing.T) {
storedJSON = map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {

View file

@ -7,41 +7,41 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type QueueTransactionsDatabase struct {
type RelayQueueDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationQueueTransactions
Table tables.RelayQueue
}
func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) {
func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationQueueTransactions
var tab tables.RelayQueue
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresQueueTransactionsTable(db)
tab, err = postgres.NewPostgresRelayQueueTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db)
tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = QueueTransactionsDatabase{
database = RelayQueueDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
@ -52,13 +52,13 @@ func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (databas
func TestShoudInsertQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
@ -68,19 +68,19 @@ func TestShoudInsertQueueTransaction(t *testing.T) {
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10)
retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
@ -93,27 +93,27 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
func TestShouldDeleteQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
@ -124,7 +124,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
@ -135,34 +135,34 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, int64(1), count)
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}

View file

@ -62,6 +62,7 @@ type Dendrite struct {
RoomServer RoomServer `yaml:"room_server"`
SyncAPI SyncAPI `yaml:"sync_api"`
UserAPI UserAPI `yaml:"user_api"`
RelayAPI RelayAPI `yaml:"relay_api"`
MSCs MSCs `yaml:"mscs"`

View file

@ -0,0 +1,38 @@
package config
type RelayAPI struct {
Matrix *Global `yaml:"-"`
InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"`
ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"`
// The database stores information used by the relay queue to
// forward transactions to remote servers.
Database DatabaseOptions `yaml:"database,omitempty"`
}
func (c *RelayAPI) Defaults(opts DefaultOpts) {
if !opts.Monolithic {
c.InternalAPI.Listen = "http://localhost:7775"
c.InternalAPI.Connect = "http://localhost:7775"
c.ExternalAPI.Listen = "http://[::]:8075"
c.Database.Defaults(10)
}
if opts.Generate {
if !opts.Monolithic {
c.Database.ConnectionString = "file:relayapi.db"
}
}
}
func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if isMonolith { // polylith required configs below
return
}
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString))
}
checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect))
}

View file

@ -23,6 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal/transactions"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
@ -44,6 +46,7 @@ type Monolith struct {
RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI
KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
@ -71,4 +74,9 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
syncapi.AddPublicRoutes(
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
)
// TODO : relayapi.AddPublicRoutes
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.RelayAPI, m.FederationAPI, m.KeyAPI)
}
}