mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Refactor all relay specific stuff into it's own component
This commit is contained in:
parent
f300a4d0e9
commit
ad53326ce8
|
|
@ -458,7 +458,7 @@ func (m *DendriteMonolith) Start() {
|
|||
switch e := event.(type) {
|
||||
case pineconeEvents.PeerAdded:
|
||||
if !relayServerSyncRunning.Load() {
|
||||
go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
|
||||
// go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
|
||||
}
|
||||
case pineconeEvents.PeerRemoved:
|
||||
if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 {
|
||||
|
|
@ -486,46 +486,6 @@ func (m *DendriteMonolith) Start() {
|
|||
}(pineconeEventChannel)
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) {
|
||||
defer running.Store(false)
|
||||
|
||||
t := time.NewTimer(relayServerRetryInterval)
|
||||
for {
|
||||
relayServersToQuery := []gomatrixserverlib.ServerName{}
|
||||
for server, complete := range m.relayServersQueried {
|
||||
if !complete {
|
||||
relayServersToQuery = append(relayServersToQuery, server)
|
||||
}
|
||||
}
|
||||
if len(relayServersToQuery) == 0 {
|
||||
// All relay servers have been synced.
|
||||
return
|
||||
}
|
||||
m.queryRelayServers(relayServersToQuery)
|
||||
t.Reset(relayServerRetryInterval)
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
}
|
||||
return
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
|
||||
for _, server := range relayServers {
|
||||
request := api.PerformRelayServerSyncRequest{RelayServer: server}
|
||||
response := api.PerformRelayServerSyncResponse{}
|
||||
err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response)
|
||||
if err == nil {
|
||||
m.relayServersQueried[server] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) Stop() {
|
||||
m.processContext.ShutdownDendrite()
|
||||
_ = m.listener.Close()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/keyserver"
|
||||
"github.com/matrix-org/dendrite/relayapi"
|
||||
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
|
|
@ -145,6 +146,7 @@ func main() {
|
|||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
|
||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
|
||||
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName)))
|
||||
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
|
||||
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
|
||||
cfg.ClientAPI.RegistrationDisabled = false
|
||||
|
|
@ -235,6 +237,20 @@ func main() {
|
|||
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
|
||||
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
|
||||
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
producer := &producers.SyncAPIProducer{
|
||||
JetStream: js,
|
||||
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
|
||||
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
|
||||
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
|
||||
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
||||
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
|
||||
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
|
||||
Config: &base.Cfg.FederationAPI,
|
||||
UserAPI: userAPI,
|
||||
}
|
||||
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
|
||||
|
||||
monolith := setup.Monolith{
|
||||
Config: base.Cfg,
|
||||
Client: conn.CreateClient(base, pQUIC),
|
||||
|
|
@ -246,6 +262,7 @@ func main() {
|
|||
RoomserverAPI: rsAPI,
|
||||
UserAPI: userAPI,
|
||||
KeyAPI: keyAPI,
|
||||
RelayAPI: &relayAPI,
|
||||
ExtPublicRoomsProvider: roomProvider,
|
||||
ExtUserDirectoryProvider: userProvider,
|
||||
}
|
||||
|
|
@ -319,32 +336,12 @@ func main() {
|
|||
relayServerSyncRunning := atomic.NewBool(false)
|
||||
stopRelayServerSync := make(chan bool)
|
||||
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.FederationAPI.Matrix.JetStream)
|
||||
producer := &producers.SyncAPIProducer{
|
||||
JetStream: js,
|
||||
TopicReceiptEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
|
||||
TopicSendToDeviceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
|
||||
TopicTypingEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
|
||||
TopicPresenceEvent: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
||||
TopicDeviceListUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
|
||||
TopicSigningKeyUpdate: base.Cfg.FederationAPI.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
|
||||
Config: &base.Cfg.FederationAPI,
|
||||
UserAPI: userAPI,
|
||||
}
|
||||
|
||||
m := RelayServerRetriever{
|
||||
Context: context.Background(),
|
||||
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
|
||||
FederationAPI: fsAPI,
|
||||
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
|
||||
RelayAPI: relayapi.NewRelayAPI(
|
||||
federation,
|
||||
rsAPI,
|
||||
keyRing,
|
||||
producer,
|
||||
cfg.Global.Presence.EnableInbound,
|
||||
cfg.Global.ServerName,
|
||||
),
|
||||
RelayAPI: monolith.RelayAPI,
|
||||
}
|
||||
m.InitializeRelayServers(eLog)
|
||||
|
||||
|
|
@ -387,7 +384,7 @@ type RelayServerRetriever struct {
|
|||
ServerName gomatrixserverlib.ServerName
|
||||
FederationAPI api.FederationInternalAPI
|
||||
RelayServersQueried map[gomatrixserverlib.ServerName]bool
|
||||
RelayAPI relayapi.RelayAPI
|
||||
RelayAPI relayServerAPI.RelayInternalAPI
|
||||
}
|
||||
|
||||
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
|
||||
|
|
@ -440,7 +437,12 @@ func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverli
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = m.RelayAPI.PerformRelayServerSync(*userID, server)
|
||||
request := relayServerAPI.PerformRelayServerSyncRequest{
|
||||
UserID: *userID,
|
||||
RelayServer: server,
|
||||
}
|
||||
response := relayServerAPI.PerformRelayServerSyncResponse{}
|
||||
err = m.RelayAPI.PerformRelayServerSync(context.Background(), &request, &response)
|
||||
if err == nil {
|
||||
m.RelayServersQueried[server] = true
|
||||
// TODO : What happens if your relay receives new messages after this point?
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ type FederationInternalAPI interface {
|
|||
gomatrixserverlib.KeyDatabase
|
||||
ClientFederationAPI
|
||||
RoomserverFederationAPI
|
||||
RelayServerAPI
|
||||
|
||||
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
|
|
@ -45,22 +44,6 @@ type FederationInternalAPI interface {
|
|||
) error
|
||||
}
|
||||
|
||||
type RelayServerAPI interface {
|
||||
// Store async transactions for forwarding to the destination at a later time.
|
||||
PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *PerformStoreAsyncRequest,
|
||||
response *PerformStoreAsyncResponse,
|
||||
) error
|
||||
|
||||
// Obtain the oldest stored transaction for the specified userID.
|
||||
QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *QueryAsyncTransactionsRequest,
|
||||
response *QueryAsyncTransactionsResponse,
|
||||
) error
|
||||
}
|
||||
|
||||
type ClientFederationAPI interface {
|
||||
// Query the server names of the joined hosts in a room.
|
||||
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
)
|
||||
|
|
@ -847,64 +846,3 @@ func (r *FederationInternalAPI) QueryRelayServers(
|
|||
response.RelayServers = relayServers
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformStoreAsync implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *api.PerformStoreAsyncRequest,
|
||||
response *api.PerformStoreAsyncResponse,
|
||||
) error {
|
||||
logrus.Warnf("Storing transaction for %v", request.UserID)
|
||||
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.db.AssociateAsyncTransactionWithDestinations(
|
||||
ctx,
|
||||
map[gomatrixserverlib.UserID]struct{}{
|
||||
request.UserID: {},
|
||||
},
|
||||
request.Txn.TransactionID,
|
||||
receipt)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryAsyncTransactions implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *api.QueryAsyncTransactionsRequest,
|
||||
response *api.QueryAsyncTransactionsResponse,
|
||||
) error {
|
||||
logrus.Warnf("Obtaining transaction for %v", request.UserID)
|
||||
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
|
||||
// TODO : Should delete transaction json from table if no more associations
|
||||
// Maybe track last received transaction, and send that as part of the request,
|
||||
// then delete before getting the new events from the db.
|
||||
if transaction != nil && receipt != nil {
|
||||
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO : Clean async transactions json
|
||||
}
|
||||
|
||||
// TODO : These db calls should happen at the same time right?
|
||||
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.RemainingCount = uint32(count)
|
||||
if transaction != nil {
|
||||
response.Txn = *transaction
|
||||
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,9 +26,6 @@ const (
|
|||
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
|
||||
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
|
||||
|
||||
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
|
||||
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
|
||||
|
||||
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
|
||||
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
|
||||
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
|
||||
|
|
@ -525,25 +522,3 @@ func (h *httpFederationInternalAPI) QueryRelayServers(
|
|||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *api.PerformStoreAsyncRequest,
|
||||
response *api.PerformStoreAsyncResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"PerformStoreAsync", h.federationAPIURL+FederationAPIPerformStoreAsyncPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *api.QueryAsyncTransactionsRequest,
|
||||
response *api.QueryAsyncTransactionsResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryAsyncTransactions", h.federationAPIURL+FederationAPIQueryAsyncTransactionsPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,16 +43,6 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
|
|||
httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformStoreAsyncPath,
|
||||
httputil.MakeInternalRPCAPI("FederationAPIPerformStoreAsync", intAPI.PerformStoreAsync),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIQueryAsyncTransactionsPath,
|
||||
httputil.MakeInternalRPCAPI("FederationAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformWakeupServers,
|
||||
httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers),
|
||||
|
|
|
|||
|
|
@ -1,199 +0,0 @@
|
|||
package routing_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/internal"
|
||||
"github.com/matrix-org/dendrite/federationapi/routing"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
|
||||
testDestination = gomatrixserverlib.ServerName("white.orchard")
|
||||
)
|
||||
|
||||
type testDatabase struct {
|
||||
nid int64
|
||||
nidMutex sync.Mutex
|
||||
transactions map[int64]json.RawMessage
|
||||
associations map[gomatrixserverlib.ServerName][]int64
|
||||
}
|
||||
|
||||
func createDatabase() *testDatabase {
|
||||
return &testDatabase{
|
||||
nid: 1,
|
||||
nidMutex: sync.Mutex{},
|
||||
transactions: make(map[int64]json.RawMessage),
|
||||
associations: make(map[gomatrixserverlib.ServerName][]int64),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *testDatabase) InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
|
||||
if _, ok := d.associations[serverName]; !ok {
|
||||
d.associations[serverName] = []int64{}
|
||||
}
|
||||
d.associations[serverName] = append(d.associations[serverName], nid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
|
||||
for _, nid := range jsonNIDs {
|
||||
for index, associatedNID := range d.associations[serverName] {
|
||||
if associatedNID == nid {
|
||||
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
|
||||
results := []int64{}
|
||||
resultCount := limit
|
||||
if limit > len(d.associations[serverName]) {
|
||||
resultCount = len(d.associations[serverName])
|
||||
}
|
||||
if resultCount > 0 {
|
||||
for i := 0; i < resultCount; i++ {
|
||||
results = append(results, d.associations[serverName][i])
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
|
||||
return int64(len(d.associations[serverName])), nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
|
||||
d.nidMutex.Lock()
|
||||
defer d.nidMutex.Unlock()
|
||||
|
||||
nid := d.nid
|
||||
d.transactions[nid] = []byte(json)
|
||||
d.nid++
|
||||
|
||||
return nid, nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
|
||||
for _, nid := range nids {
|
||||
delete(d.transactions, nid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
|
||||
result := make(map[int64][]byte)
|
||||
for _, nid := range jsonNIDs {
|
||||
if transaction, ok := d.transactions[nid]; ok {
|
||||
result[nid] = transaction
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func createTransaction() gomatrixserverlib.Transaction {
|
||||
txn := gomatrixserverlib.Transaction{}
|
||||
txn.PDUs = []json.RawMessage{
|
||||
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
|
||||
}
|
||||
txn.Origin = testOrigin
|
||||
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
txn.Destination = testDestination
|
||||
return txn
|
||||
}
|
||||
|
||||
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
|
||||
txn := createTransaction()
|
||||
var federationPathPrefixV1 = "/_matrix/federation/v1"
|
||||
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
|
||||
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
|
||||
request.SetContent(txn)
|
||||
|
||||
return txn, request
|
||||
}
|
||||
|
||||
func TestEmptyForwardReturnsOk(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
_, request := createFederationRequest(*userID)
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
|
||||
response := routing.ForwardAsync(httpReq, &request, fedAPI, "1", *userID)
|
||||
|
||||
expected := 200
|
||||
if response.Code != expected {
|
||||
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
}
|
||||
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
inputTransaction, request := createFederationRequest(*userID)
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
|
||||
response := routing.ForwardAsync(
|
||||
httpReq, &request, fedAPI, inputTransaction.TransactionID, *userID)
|
||||
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
|
||||
expected := 200
|
||||
if response.Code != expected {
|
||||
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
|
||||
}
|
||||
if transactionCount != 1 {
|
||||
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
|
||||
}
|
||||
if transaction.TransactionID != inputTransaction.TransactionID {
|
||||
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
|
||||
inputTransaction.TransactionID, transaction.TransactionID)
|
||||
}
|
||||
}
|
||||
|
|
@ -133,37 +133,6 @@ func Setup(
|
|||
},
|
||||
)).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeFedAPI(
|
||||
"federation_forward_async", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username was invalid"),
|
||||
}
|
||||
}
|
||||
return ForwardAsync(
|
||||
httpReq, request, fsAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
|
||||
*userID,
|
||||
)
|
||||
},
|
||||
)).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/async_events/{userID}", MakeFedAPI(
|
||||
"federation_async_events", "", cfg.Matrix.IsLocalServerName, keys, wakeup,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username was invalid"),
|
||||
}
|
||||
}
|
||||
return GetAsyncEvents(httpReq, request, fsAPI, *userID)
|
||||
},
|
||||
)).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
|
||||
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -51,12 +51,6 @@ type Database interface {
|
|||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
|
||||
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
|
||||
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
|
||||
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
|
||||
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
|
||||
|
||||
// these don't have contexts passed in as we want things to happen regardless of the request context
|
||||
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
|
||||
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error
|
||||
|
|
|
|||
|
|
@ -58,18 +58,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueTransactions, err := NewPostgresQueueTransactionsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueJSON, err := NewPostgresQueueJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transactionJSON, err := NewPostgresTransactionJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -111,24 +103,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationQueueTransactions: queueTransactions,
|
||||
FederationTransactionJSON: transactionJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
NotaryServerKeysJSON: notaryJSON,
|
||||
NotaryServerKeysMetadata: notaryMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
NotaryServerKeysJSON: notaryJSON,
|
||||
NotaryServerKeysMetadata: notaryMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ package shared
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
|
@ -29,24 +28,22 @@ import (
|
|||
)
|
||||
|
||||
type Database struct {
|
||||
DB *sql.DB
|
||||
IsLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
Cache caching.FederationCache
|
||||
Writer sqlutil.Writer
|
||||
FederationQueuePDUs tables.FederationQueuePDUs
|
||||
FederationQueueEDUs tables.FederationQueueEDUs
|
||||
FederationQueueJSON tables.FederationQueueJSON
|
||||
FederationJoinedHosts tables.FederationJoinedHosts
|
||||
FederationBlacklist tables.FederationBlacklist
|
||||
FederationAssumedOffline tables.FederationAssumedOffline
|
||||
FederationRelayServers tables.FederationRelayServers
|
||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||
FederationInboundPeeks tables.FederationInboundPeeks
|
||||
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
||||
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
|
||||
ServerSigningKeys tables.FederationServerSigningKeys
|
||||
FederationQueueTransactions tables.FederationQueueTransactions
|
||||
FederationTransactionJSON tables.FederationTransactionJSON
|
||||
DB *sql.DB
|
||||
IsLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
Cache caching.FederationCache
|
||||
Writer sqlutil.Writer
|
||||
FederationQueuePDUs tables.FederationQueuePDUs
|
||||
FederationQueueEDUs tables.FederationQueueEDUs
|
||||
FederationQueueJSON tables.FederationQueueJSON
|
||||
FederationJoinedHosts tables.FederationJoinedHosts
|
||||
FederationBlacklist tables.FederationBlacklist
|
||||
FederationAssumedOffline tables.FederationAssumedOffline
|
||||
FederationRelayServers tables.FederationRelayServers
|
||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||
FederationInboundPeeks tables.FederationInboundPeeks
|
||||
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
||||
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
|
||||
ServerSigningKeys tables.FederationServerSigningKeys
|
||||
}
|
||||
|
||||
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
|
||||
|
|
@ -61,6 +58,10 @@ func NewReceipt(nid int64) Receipt {
|
|||
return Receipt{nid: nid}
|
||||
}
|
||||
|
||||
func (r *Receipt) GetNID() int64 {
|
||||
return r.nid
|
||||
}
|
||||
|
||||
func (r *Receipt) String() string {
|
||||
return fmt.Sprintf("%d", r.nid)
|
||||
}
|
||||
|
|
@ -308,108 +309,3 @@ func (d *Database) GetNotaryKeys(
|
|||
})
|
||||
return sks, err
|
||||
}
|
||||
|
||||
func (d *Database) StoreAsyncTransaction(
|
||||
ctx context.Context, txn gomatrixserverlib.Transaction,
|
||||
) (*Receipt, error) {
|
||||
var err error
|
||||
json, err := json.Marshal(txn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
|
||||
}
|
||||
|
||||
var nid int64
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nid, err = d.FederationTransactionJSON.InsertTransactionJSON(ctx, txn, string(json))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.insertTransactionJSON: %w", err)
|
||||
}
|
||||
return &Receipt{
|
||||
nid: nid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) AssociateAsyncTransactionWithDestinations(
|
||||
ctx context.Context,
|
||||
destinations map[gomatrixserverlib.UserID]struct{},
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
receipt *Receipt,
|
||||
) error {
|
||||
for destination := range destinations {
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.FederationQueueTransactions.InsertQueueTransaction(
|
||||
ctx, txn, transactionID, destination.Domain(), receipt.nid)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.insertQueueTransaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) CleanAsyncTransactions(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
receipts []*Receipt,
|
||||
) error {
|
||||
println(len(receipts))
|
||||
nids := make([]int64, len(receipts))
|
||||
for i, receipt := range receipts {
|
||||
nids[i] = receipt.nid
|
||||
}
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.FederationQueueTransactions.DeleteQueueTransactions(ctx, txn, userID.Domain(), nids)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.insertQueueTransaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransaction(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) (*gomatrixserverlib.Transaction, *Receipt, error) {
|
||||
nids, err := d.FederationQueueTransactions.SelectQueueTransactions(ctx, nil, userID.Domain(), 1)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("d.SelectQueueTransaction: %w", err)
|
||||
}
|
||||
if len(nids) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
txns := map[int64][]byte{}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
txns, err = d.FederationTransactionJSON.SelectTransactionJSON(ctx, txn, nids)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("d.SelectTransactionJSON: %w", err)
|
||||
}
|
||||
|
||||
transaction := &gomatrixserverlib.Transaction{}
|
||||
err = json.Unmarshal(txns[nids[0]], transaction)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
|
||||
}
|
||||
|
||||
receipt := NewReceipt(nids[0])
|
||||
return transaction, &receipt, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransactionCount(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) (int64, error) {
|
||||
count, err := d.FederationQueueTransactions.SelectQueueTransactionCount(ctx, nil, userID.Domain())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("d.SelectQueueTransactionCount: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,14 +55,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueTransactions, err := NewSQLiteQueueTransactionsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transactionJSON, err := NewSQLiteTransactionJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -104,24 +96,22 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationQueueTransactions: queueTransactions,
|
||||
FederationTransactionJSON: transactionJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
NotaryServerKeysJSON: notaryKeys,
|
||||
NotaryServerKeysMetadata: notaryKeysMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
NotaryServerKeysJSON: notaryKeys,
|
||||
NotaryServerKeysMetadata: notaryKeysMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
|
|
|||
81
relayapi/api/api.go
Normal file
81
relayapi/api/api.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// RelayInternalAPI is used to query information from the relay server.
|
||||
type RelayInternalAPI interface {
|
||||
RelayServerAPI
|
||||
|
||||
PerformRelayServerSync(
|
||||
ctx context.Context,
|
||||
request *PerformRelayServerSyncRequest,
|
||||
response *PerformRelayServerSyncResponse,
|
||||
) error
|
||||
}
|
||||
|
||||
type RelayServerAPI interface {
|
||||
// Store async transactions for forwarding to the destination at a later time.
|
||||
PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *PerformStoreAsyncRequest,
|
||||
response *PerformStoreAsyncResponse,
|
||||
) error
|
||||
|
||||
// Obtain the oldest stored transaction for the specified userID.
|
||||
QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *QueryAsyncTransactionsRequest,
|
||||
response *QueryAsyncTransactionsResponse,
|
||||
) error
|
||||
}
|
||||
|
||||
type PerformRelayServerSyncRequest struct {
|
||||
UserID gomatrixserverlib.UserID `json:"user_id"`
|
||||
RelayServer gomatrixserverlib.ServerName `json:"relay_name"`
|
||||
}
|
||||
|
||||
type PerformRelayServerSyncResponse struct {
|
||||
}
|
||||
|
||||
type QueryRelayServersRequest struct {
|
||||
Server gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryRelayServersResponse struct {
|
||||
RelayServers []gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type PerformStoreAsyncRequest struct {
|
||||
Txn gomatrixserverlib.Transaction `json:"transaction"`
|
||||
UserID gomatrixserverlib.UserID `json:"user_id"`
|
||||
}
|
||||
|
||||
type PerformStoreAsyncResponse struct {
|
||||
}
|
||||
|
||||
type QueryAsyncTransactionsRequest struct {
|
||||
UserID gomatrixserverlib.UserID `json:"user_id"`
|
||||
}
|
||||
|
||||
type QueryAsyncTransactionsResponse struct {
|
||||
Txn gomatrixserverlib.Transaction `json:"transaction"`
|
||||
RemainingCount uint32 `json:"remaining"`
|
||||
}
|
||||
52
relayapi/internal/api.go
Normal file
52
relayapi/internal/api.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/federationapi/producers"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage"
|
||||
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type RelayInternalAPI struct {
|
||||
db storage.Database
|
||||
fedClient *gomatrixserverlib.FederationClient
|
||||
rsAPI rsAPI.RoomserverInternalAPI
|
||||
keyRing *gomatrixserverlib.KeyRing
|
||||
producer *producers.SyncAPIProducer
|
||||
presenceEnabledInbound bool
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func NewRelayInternalAPI(
|
||||
db storage.Database,
|
||||
fedClient *gomatrixserverlib.FederationClient,
|
||||
rsAPI rsAPI.RoomserverInternalAPI,
|
||||
keyRing *gomatrixserverlib.KeyRing,
|
||||
producer *producers.SyncAPIProducer,
|
||||
presenceEnabledInbound bool,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) RelayInternalAPI {
|
||||
return RelayInternalAPI{
|
||||
db: db,
|
||||
fedClient: fedClient,
|
||||
rsAPI: rsAPI,
|
||||
keyRing: keyRing,
|
||||
producer: producer,
|
||||
presenceEnabledInbound: presenceEnabledInbound,
|
||||
serverName: serverName,
|
||||
}
|
||||
}
|
||||
131
relayapi/internal/perform.go
Normal file
131
relayapi/internal/perform.go
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PerformRelayServerSync implements api.FederationInternalAPI
|
||||
func (r *RelayInternalAPI) PerformRelayServerSync(
|
||||
ctx context.Context,
|
||||
request *api.PerformRelayServerSyncRequest,
|
||||
response *api.PerformRelayServerSyncResponse,
|
||||
) error {
|
||||
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
|
||||
if err != nil {
|
||||
logrus.Errorf("GetAsyncEvents: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
r.processTransaction(&asyncResponse.Transaction)
|
||||
|
||||
for asyncResponse.Remaining > 0 {
|
||||
asyncResponse, err := r.fedClient.GetAsyncEvents(ctx, request.UserID, request.RelayServer)
|
||||
if err != nil {
|
||||
logrus.Errorf("GetAsyncEvents: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
r.processTransaction(&asyncResponse.Transaction)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformStoreAsync implements api.RelayInternalAPI
|
||||
func (r *RelayInternalAPI) PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *api.PerformStoreAsyncRequest,
|
||||
response *api.PerformStoreAsyncResponse,
|
||||
) error {
|
||||
logrus.Warnf("Storing transaction for %v", request.UserID)
|
||||
receipt, err := r.db.StoreAsyncTransaction(ctx, request.Txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.db.AssociateAsyncTransactionWithDestinations(
|
||||
ctx,
|
||||
map[gomatrixserverlib.UserID]struct{}{
|
||||
request.UserID: {},
|
||||
},
|
||||
request.Txn.TransactionID,
|
||||
receipt)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryAsyncTransactions implements api.RelayInternalAPI
|
||||
func (r *RelayInternalAPI) QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *api.QueryAsyncTransactionsRequest,
|
||||
response *api.QueryAsyncTransactionsResponse,
|
||||
) error {
|
||||
logrus.Warnf("Obtaining transaction for %v", request.UserID)
|
||||
transaction, receipt, err := r.db.GetAsyncTransaction(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO : Shouldn't be deleting unless the transaction was successfully returned...
|
||||
// TODO : Should delete transaction json from table if no more associations
|
||||
// Maybe track last received transaction, and send that as part of the request,
|
||||
// then delete before getting the new events from the db.
|
||||
if transaction != nil && receipt != nil {
|
||||
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO : Clean async transactions json
|
||||
}
|
||||
|
||||
// TODO : These db calls should happen at the same time right?
|
||||
count, err := r.db.GetAsyncTransactionCount(ctx, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.RemainingCount = uint32(count)
|
||||
if transaction != nil {
|
||||
response.Txn = *transaction
|
||||
logrus.Warnf("Obtained transaction: %v", transaction.TransactionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
|
||||
logrus.Warn("Processing transaction from relay server")
|
||||
mu := internal.NewMutexByRoom()
|
||||
t := internal.NewTxnReq(
|
||||
r.rsAPI,
|
||||
nil,
|
||||
r.serverName,
|
||||
r.keyRing,
|
||||
mu,
|
||||
r.producer,
|
||||
r.presenceEnabledInbound,
|
||||
txn.PDUs,
|
||||
txn.EDUs,
|
||||
txn.Origin,
|
||||
txn.TransactionID,
|
||||
txn.Destination)
|
||||
|
||||
t.ProcessTransaction(context.TODO())
|
||||
}
|
||||
70
relayapi/inthttp/client.go
Normal file
70
relayapi/inthttp/client.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
)
|
||||
|
||||
// HTTP paths for the internal HTTP API
|
||||
const (
|
||||
RelayAPIPerformRelayServerSyncPath = "/relayapi/performRelayServerSync"
|
||||
RelayAPIPerformStoreAsyncPath = "/relayapi/performStoreAsync"
|
||||
RelayAPIQueryAsyncTransactionsPath = "/relayapi/queryAsyncTransactions"
|
||||
)
|
||||
|
||||
// NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API.
|
||||
// If httpClient is nil an error is returned
|
||||
func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is <nil>")
|
||||
}
|
||||
return &httpRelayInternalAPI{
|
||||
relayAPIURL: relayapiURL,
|
||||
httpClient: httpClient,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type httpRelayInternalAPI struct {
|
||||
relayAPIURL string
|
||||
httpClient *http.Client
|
||||
cache caching.ServerKeyCache
|
||||
}
|
||||
|
||||
func (h *httpRelayInternalAPI) PerformRelayServerSync(
|
||||
ctx context.Context,
|
||||
request *api.PerformRelayServerSyncRequest,
|
||||
response *api.PerformRelayServerSyncResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"PerformRelayServerSync", h.relayAPIURL+RelayAPIPerformRelayServerSyncPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpRelayInternalAPI) PerformStoreAsync(
|
||||
ctx context.Context,
|
||||
request *api.PerformStoreAsyncRequest,
|
||||
response *api.PerformStoreAsyncResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"PerformStoreAsync", h.relayAPIURL+RelayAPIPerformStoreAsyncPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpRelayInternalAPI) QueryAsyncTransactions(
|
||||
ctx context.Context,
|
||||
request *api.QueryAsyncTransactionsRequest,
|
||||
response *api.QueryAsyncTransactionsResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryAsyncTransactions", h.relayAPIURL+RelayAPIQueryAsyncTransactionsPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
27
relayapi/inthttp/server.go
Normal file
27
relayapi/inthttp/server.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
)
|
||||
|
||||
// AddRoutes adds the RelayInternalAPI handlers to the http.ServeMux.
|
||||
// nolint:gocyclo
|
||||
func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router) {
|
||||
internalAPIMux.Handle(
|
||||
RelayAPIPerformRelayServerSyncPath,
|
||||
httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", intAPI.PerformRelayServerSync),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
RelayAPIPerformStoreAsyncPath,
|
||||
httputil.MakeInternalRPCAPI("RelayAPIPerformStoreAsync", intAPI.PerformStoreAsync),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
RelayAPIQueryAsyncTransactionsPath,
|
||||
httputil.MakeInternalRPCAPI("RelayAPIQueryAsyncTransactions", intAPI.QueryAsyncTransactions),
|
||||
)
|
||||
}
|
||||
|
|
@ -15,79 +15,77 @@
|
|||
package relayapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/producers"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
|
||||
"github.com/matrix-org/dendrite/relayapi/internal"
|
||||
"github.com/matrix-org/dendrite/relayapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/relayapi/routing"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage"
|
||||
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type RelayAPI struct {
|
||||
fedClient *gomatrixserverlib.FederationClient
|
||||
rsAPI rsAPI.RoomserverInternalAPI
|
||||
keyRing *gomatrixserverlib.KeyRing
|
||||
producer *producers.SyncAPIProducer
|
||||
presenceEnabledInbound bool
|
||||
serverName gomatrixserverlib.ServerName
|
||||
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
|
||||
// on the given input API.
|
||||
func AddInternalRoutes(router *mux.Router, intAPI api.RelayInternalAPI) {
|
||||
inthttp.AddRoutes(intAPI, router)
|
||||
}
|
||||
|
||||
func NewRelayAPI(
|
||||
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
|
||||
func AddPublicRoutes(
|
||||
base *base.BaseDendrite,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
fedClient *gomatrixserverlib.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
rsAPI rsAPI.FederationRoomserverAPI,
|
||||
relayAPI relayAPI.RelayInternalAPI,
|
||||
fedAPI federationAPI.FederationInternalAPI,
|
||||
keyAPI keyserverAPI.FederationKeyAPI,
|
||||
) {
|
||||
fedCfg := &base.Cfg.FederationAPI
|
||||
|
||||
relay, ok := relayAPI.(*internal.RelayInternalAPI)
|
||||
if !ok {
|
||||
panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
|
||||
"RelayInternalAPI. This is a programming error.")
|
||||
}
|
||||
|
||||
routing.Setup(
|
||||
base.PublicFederationAPIMux,
|
||||
fedCfg,
|
||||
relay,
|
||||
keyRing,
|
||||
)
|
||||
}
|
||||
|
||||
func NewRelayInternalAPI(
|
||||
base *base.BaseDendrite,
|
||||
fedClient *gomatrixserverlib.FederationClient,
|
||||
rsAPI rsAPI.RoomserverInternalAPI,
|
||||
keyRing *gomatrixserverlib.KeyRing,
|
||||
producer *producers.SyncAPIProducer,
|
||||
presenceEnabledInbound bool,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) RelayAPI {
|
||||
return RelayAPI{
|
||||
fedClient: fedClient,
|
||||
rsAPI: rsAPI,
|
||||
keyRing: keyRing,
|
||||
producer: producer,
|
||||
presenceEnabledInbound: presenceEnabledInbound,
|
||||
serverName: serverName,
|
||||
}
|
||||
}
|
||||
) internal.RelayInternalAPI {
|
||||
cfg := &base.Cfg.RelayAPI
|
||||
|
||||
// PerformRelayServerSync implements api.FederationInternalAPI
|
||||
func (r *RelayAPI) PerformRelayServerSync(userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error {
|
||||
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
|
||||
relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
|
||||
if err != nil {
|
||||
logrus.Errorf("GetAsyncEvents: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
r.processTransaction(&asyncResponse.Transaction)
|
||||
|
||||
for asyncResponse.Remaining > 0 {
|
||||
asyncResponse, err := r.fedClient.GetAsyncEvents(context.Background(), userID, relayServer)
|
||||
if err != nil {
|
||||
logrus.Errorf("GetAsyncEvents: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
r.processTransaction(&asyncResponse.Transaction)
|
||||
logrus.WithError(err).Panic("failed to connect to relay db")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RelayAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
|
||||
logrus.Warn("Processing transaction from relay server")
|
||||
mu := internal.NewMutexByRoom()
|
||||
t := internal.NewTxnReq(
|
||||
r.rsAPI,
|
||||
nil,
|
||||
r.serverName,
|
||||
r.keyRing,
|
||||
mu,
|
||||
r.producer,
|
||||
r.presenceEnabledInbound,
|
||||
txn.PDUs,
|
||||
txn.EDUs,
|
||||
txn.Origin,
|
||||
txn.TransactionID,
|
||||
txn.Destination)
|
||||
|
||||
t.ProcessTransaction(context.TODO())
|
||||
return internal.NewRelayInternalAPI(
|
||||
relayDB,
|
||||
fedClient,
|
||||
rsAPI,
|
||||
keyRing,
|
||||
producer,
|
||||
base.Cfg.Global.Presence.EnableInbound,
|
||||
base.Cfg.Global.ServerName,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package routing
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -19,12 +19,12 @@ type AsyncEventsResponse struct {
|
|||
func GetAsyncEvents(
|
||||
httpReq *http.Request,
|
||||
fedReq *gomatrixserverlib.FederationRequest,
|
||||
fedAPI api.FederationInternalAPI,
|
||||
relayAPI api.RelayInternalAPI,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) util.JSONResponse {
|
||||
logrus.Infof("Handling async_events for %v", userID)
|
||||
var response api.QueryAsyncTransactionsResponse
|
||||
err := fedAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
|
||||
err := relayAPI.QueryAsyncTransactions(httpReq.Context(), &api.QueryAsyncTransactionsRequest{UserID: userID}, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
|
|
@ -5,21 +5,21 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/internal"
|
||||
"github.com/matrix-org/dendrite/federationapi/routing"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/relayapi/internal"
|
||||
"github.com/matrix-org/dendrite/relayapi/routing"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/shared"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
testDB := storage.NewFakeRelayDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
RelayQueue: testDB,
|
||||
RelayQueueJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
|
|
@ -33,11 +33,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
|
|||
t.Fatalf("Failed to store transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
relayAPI := internal.NewRelayInternalAPI(
|
||||
&db, nil, nil, nil, nil, false, "",
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
|
|
@ -46,11 +46,11 @@ func TestGetAsyncEmptyDatabaseReturnsNothing(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
testDB := storage.NewFakeRelayDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
RelayQueue: testDB,
|
||||
RelayQueueJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
|
|
@ -74,11 +74,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
|
|||
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
|
||||
}
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
relayAPI := internal.NewRelayInternalAPI(
|
||||
&db, nil, nil, nil, nil, false, "",
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
|
|
@ -87,11 +87,11 @@ func TestGetAsyncReturnsSavedTransaction(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
|
||||
testDB := createDatabase()
|
||||
testDB := storage.NewFakeRelayDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
FederationQueueTransactions: testDB,
|
||||
FederationTransactionJSON: testDB,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
RelayQueue: testDB,
|
||||
RelayQueueJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
|
|
@ -131,18 +131,18 @@ func TestGetAsyncReturnsMultipleSavedTransactions(t *testing.T) {
|
|||
t.Fatalf("Failed to associate transaction with user: %s", err.Error())
|
||||
}
|
||||
|
||||
fedAPI := internal.NewFederationInternalAPI(
|
||||
&db, &config.FederationAPI{}, nil, nil, nil, nil, nil, nil,
|
||||
relayAPI := internal.NewRelayInternalAPI(
|
||||
&db, nil, nil, nil, nil, false, "",
|
||||
)
|
||||
|
||||
response := routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
response := routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse := response.JSON.(routing.AsyncEventsResponse)
|
||||
assert.Equal(t, uint32(1), jsonResponse.Remaining)
|
||||
assert.Equal(t, transaction, jsonResponse.Transaction)
|
||||
|
||||
response = routing.GetAsyncEvents(httpReq, nil, fedAPI, *userID)
|
||||
response = routing.GetAsyncEvents(httpReq, nil, &relayAPI, *userID)
|
||||
assert.Equal(t, http.StatusOK, response.Code)
|
||||
|
||||
jsonResponse = response.JSON.(routing.AsyncEventsResponse)
|
||||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/relayapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
|
@ -15,7 +15,7 @@ import (
|
|||
func ForwardAsync(
|
||||
httpReq *http.Request,
|
||||
fedReq *gomatrixserverlib.FederationRequest,
|
||||
fedAPI api.FederationInternalAPI,
|
||||
relayAPI api.RelayInternalAPI,
|
||||
txnID gomatrixserverlib.TransactionID,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -54,7 +54,7 @@ func ForwardAsync(
|
|||
UserID: userID,
|
||||
}
|
||||
res := api.PerformStoreAsyncResponse{}
|
||||
err := fedAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
|
||||
err := relayAPI.PerformStoreAsync(httpReq.Context(), &req, &res)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
111
relayapi/routing/forwardasync_test.go
Normal file
111
relayapi/routing/forwardasync_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package routing_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/internal"
|
||||
"github.com/matrix-org/dendrite/relayapi/routing"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/shared"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
|
||||
testDestination = gomatrixserverlib.ServerName("white.orchard")
|
||||
)
|
||||
|
||||
func createTransaction() gomatrixserverlib.Transaction {
|
||||
txn := gomatrixserverlib.Transaction{}
|
||||
txn.PDUs = []json.RawMessage{
|
||||
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
|
||||
}
|
||||
txn.Origin = testOrigin
|
||||
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
txn.Destination = testDestination
|
||||
return txn
|
||||
}
|
||||
|
||||
func createFederationRequest(userID gomatrixserverlib.UserID) (gomatrixserverlib.Transaction, gomatrixserverlib.FederationRequest) {
|
||||
txn := createTransaction()
|
||||
var federationPathPrefixV1 = "/_matrix/federation/v1"
|
||||
path := federationPathPrefixV1 + "/forward_async/" + string(txn.TransactionID) + "/" + userID.Raw()
|
||||
request := gomatrixserverlib.NewFederationRequest("PUT", txn.Origin, txn.Destination, path)
|
||||
request.SetContent(txn)
|
||||
|
||||
return txn, request
|
||||
}
|
||||
|
||||
func TestEmptyForwardReturnsOk(t *testing.T) {
|
||||
testDB := storage.NewFakeRelayDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
RelayQueue: testDB,
|
||||
RelayQueueJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
_, request := createFederationRequest(*userID)
|
||||
|
||||
relayAPI := internal.NewRelayInternalAPI(
|
||||
&db, nil, nil, nil, nil, false, "",
|
||||
)
|
||||
|
||||
response := routing.ForwardAsync(httpReq, &request, &relayAPI, "1", *userID)
|
||||
|
||||
expected := 200
|
||||
if response.Code != expected {
|
||||
t.Fatalf("Expected: %v, Actual: %v", expected, response.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
|
||||
testDB := storage.NewFakeRelayDatabase()
|
||||
db := shared.Database{
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
RelayQueue: testDB,
|
||||
RelayQueueJSON: testDB,
|
||||
}
|
||||
httpReq := &http.Request{}
|
||||
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid userID: %s", err.Error())
|
||||
}
|
||||
inputTransaction, request := createFederationRequest(*userID)
|
||||
|
||||
relayAPI := internal.NewRelayInternalAPI(
|
||||
&db, nil, nil, nil, nil, false, "",
|
||||
)
|
||||
|
||||
response := routing.ForwardAsync(
|
||||
httpReq, &request, &relayAPI, inputTransaction.TransactionID, *userID)
|
||||
transaction, _, err := db.GetAsyncTransaction(context.TODO(), *userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
transactionCount, err := db.GetAsyncTransactionCount(context.TODO(), *userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
|
||||
expected := 200
|
||||
if response.Code != expected {
|
||||
t.Fatalf("Expected Return Code: %v, Actual: %v", expected, response.Code)
|
||||
}
|
||||
if transactionCount != 1 {
|
||||
t.Fatalf("Expected count of 1, Actual: %d", transactionCount)
|
||||
}
|
||||
if transaction.TransactionID != inputTransaction.TransactionID {
|
||||
t.Fatalf("Expected Transaction ID: %s, Actual: %s",
|
||||
inputTransaction.TransactionID, transaction.TransactionID)
|
||||
}
|
||||
}
|
||||
123
relayapi/routing/routing.go
Normal file
123
relayapi/routing/routing.go
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// Setup registers HTTP handlers with the given ServeMux.
|
||||
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
|
||||
// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
|
||||
// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
|
||||
//
|
||||
// Due to Setup being used to call many other functions, a gocyclo nolint is
|
||||
// applied:
|
||||
// nolint: gocyclo
|
||||
func Setup(
|
||||
fedMux *mux.Router,
|
||||
cfg *config.FederationAPI,
|
||||
relayAPI *relayInternal.RelayInternalAPI,
|
||||
keys gomatrixserverlib.JSONVerifier,
|
||||
) {
|
||||
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
|
||||
|
||||
v1fedmux.Handle("/forward_async/{txnID}/{userID}", MakeRelayAPI(
|
||||
"relay_forward_async", "", cfg.Matrix.IsLocalServerName, keys,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username was invalid"),
|
||||
}
|
||||
}
|
||||
return ForwardAsync(
|
||||
httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
|
||||
*userID,
|
||||
)
|
||||
},
|
||||
)).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v1fedmux.Handle("/async_events/{userID}", MakeRelayAPI(
|
||||
"relay_async_events", "", cfg.Matrix.IsLocalServerName, keys,
|
||||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username was invalid"),
|
||||
}
|
||||
}
|
||||
return GetAsyncEvents(httpReq, request, relayAPI, *userID)
|
||||
},
|
||||
)).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
||||
// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
|
||||
func MakeRelayAPI(
|
||||
metricsName string, serverName gomatrixserverlib.ServerName,
|
||||
isLocalServerName func(gomatrixserverlib.ServerName) bool,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
|
||||
req, time.Now(), serverName, isLocalServerName, keyRing,
|
||||
)
|
||||
if fedReq == nil {
|
||||
return errResp
|
||||
}
|
||||
// add the user to Sentry, if enabled
|
||||
hub := sentry.GetHubFromContext(req.Context())
|
||||
if hub != nil {
|
||||
hub.Scope().SetTag("origin", string(fedReq.Origin()))
|
||||
hub.Scope().SetTag("uri", fedReq.RequestURI())
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if hub != nil {
|
||||
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
|
||||
}
|
||||
// re-panic to return the 500
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
|
||||
}
|
||||
|
||||
jsonRes := f(req, fedReq, vars)
|
||||
// do not log 4xx as errors as they are client fails, not server fails
|
||||
if hub != nil && jsonRes.Code >= 500 {
|
||||
hub.Scope().SetExtra("response", jsonRes)
|
||||
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
|
||||
}
|
||||
return jsonRes
|
||||
}
|
||||
return httputil.MakeExternalAPI(metricsName, h)
|
||||
}
|
||||
95
relayapi/storage/fake_relay_db.go
Normal file
95
relayapi/storage/fake_relay_db.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type testDatabase struct {
|
||||
nid int64
|
||||
nidMutex sync.Mutex
|
||||
transactions map[int64]json.RawMessage
|
||||
associations map[gomatrixserverlib.ServerName][]int64
|
||||
}
|
||||
|
||||
func NewFakeRelayDatabase() *testDatabase {
|
||||
return &testDatabase{
|
||||
nid: 1,
|
||||
nidMutex: sync.Mutex{},
|
||||
transactions: make(map[int64]json.RawMessage),
|
||||
associations: make(map[gomatrixserverlib.ServerName][]int64),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
|
||||
if _, ok := d.associations[serverName]; !ok {
|
||||
d.associations[serverName] = []int64{}
|
||||
}
|
||||
d.associations[serverName] = append(d.associations[serverName], nid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
|
||||
for _, nid := range jsonNIDs {
|
||||
for index, associatedNID := range d.associations[serverName] {
|
||||
if associatedNID == nid {
|
||||
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
|
||||
results := []int64{}
|
||||
resultCount := limit
|
||||
if limit > len(d.associations[serverName]) {
|
||||
resultCount = len(d.associations[serverName])
|
||||
}
|
||||
if resultCount > 0 {
|
||||
for i := 0; i < resultCount; i++ {
|
||||
results = append(results, d.associations[serverName][i])
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
|
||||
return int64(len(d.associations[serverName])), nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
|
||||
d.nidMutex.Lock()
|
||||
defer d.nidMutex.Unlock()
|
||||
|
||||
nid := d.nid
|
||||
d.transactions[nid] = []byte(json)
|
||||
d.nid++
|
||||
|
||||
return nid, nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
|
||||
for _, nid := range nids {
|
||||
delete(d.transactions, nid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
|
||||
result := make(map[int64][]byte)
|
||||
for _, nid := range jsonNIDs {
|
||||
if transaction, ok := d.transactions[nid]; ok {
|
||||
result[nid] = transaction
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
30
relayapi/storage/interface.go
Normal file
30
relayapi/storage/interface.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
StoreAsyncTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error)
|
||||
AssociateAsyncTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error
|
||||
CleanAsyncTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error
|
||||
GetAsyncTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error)
|
||||
GetAsyncTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
|
||||
}
|
||||
|
|
@ -23,60 +23,60 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const transactionJSONSchema = `
|
||||
-- The federationsender_transaction_json table contains event contents that
|
||||
const relayQueueJSONSchema = `
|
||||
-- The relayapi_queue_json table contains event contents that
|
||||
-- we are storing for future forwarding.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
|
||||
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
|
||||
-- The JSON NID. This allows cross-referencing to find the JSON blob.
|
||||
json_nid BIGSERIAL,
|
||||
-- The JSON body. Text so that we preserve UTF-8.
|
||||
json_body TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
|
||||
ON federationsender_transaction_json (json_nid);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
|
||||
ON relayapi_queue_json (json_nid);
|
||||
`
|
||||
|
||||
const insertTransactionJSONSQL = "" +
|
||||
"INSERT INTO federationsender_transaction_json (json_body)" +
|
||||
const insertQueueJSONSQL = "" +
|
||||
"INSERT INTO relayapi_queue_json (json_body)" +
|
||||
" VALUES ($1)" +
|
||||
" RETURNING json_nid"
|
||||
|
||||
const deleteTransactionJSONSQL = "" +
|
||||
"DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)"
|
||||
const deleteQueueJSONSQL = "" +
|
||||
"DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
|
||||
|
||||
const selectTransactionJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
|
||||
const selectQueueJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM relayapi_queue_json" +
|
||||
" WHERE json_nid = ANY($1)"
|
||||
|
||||
type transactionJSONStatements struct {
|
||||
type relayQueueJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertJSONStmt *sql.Stmt
|
||||
deleteJSONStmt *sql.Stmt
|
||||
selectJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
|
||||
s = &transactionJSONStatements{
|
||||
func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
|
||||
s = &relayQueueJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(transactionJSONSchema)
|
||||
_, err = s.db.Exec(relayQueueJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil {
|
||||
if s.insertJSONStmt, err = s.db.Prepare(insertQueueJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil {
|
||||
if s.deleteJSONStmt, err = s.db.Prepare(deleteQueueJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil {
|
||||
if s.selectJSONStmt, err = s.db.Prepare(selectQueueJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) InsertTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) InsertQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, json string,
|
||||
) (int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||
|
|
@ -87,7 +87,7 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
|
|||
return lastid, nil
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) DeleteTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) DeleteQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, nids []int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
|
||||
|
|
@ -95,7 +95,7 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) SelectTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) SelectQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
|
||||
) (map[int64][]byte, error) {
|
||||
blobs := map[int64][]byte{}
|
||||
|
|
@ -24,79 +24,79 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queueTransactionsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
|
||||
const relayQueueSchema = `
|
||||
CREATE TABLE IF NOT EXISTS relayapi_queue (
|
||||
-- The transaction ID that was generated before persisting the event.
|
||||
transaction_id TEXT NOT NULL,
|
||||
-- The destination server that we will send the event to.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_transaction_json table.
|
||||
-- The JSON NID from the relayapi_queue_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
|
||||
ON federationsender_queue_transactions (json_nid, server_name);
|
||||
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
|
||||
ON federationsender_queue_transactions (json_nid);
|
||||
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
|
||||
ON federationsender_queue_transactions (server_name);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
|
||||
ON relayapi_queue (json_nid, server_name);
|
||||
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
|
||||
ON relayapi_queue (json_nid);
|
||||
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
|
||||
ON relayapi_queue (server_name);
|
||||
`
|
||||
|
||||
const insertQueueTransactionSQL = "" +
|
||||
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
|
||||
const insertQueueEntrySQL = "" +
|
||||
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueueTransactionsSQL = "" +
|
||||
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)"
|
||||
const deleteQueueEntriesSQL = "" +
|
||||
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
|
||||
|
||||
const selectQueueTransactionsSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_transactions" +
|
||||
const selectQueueEntriesSQL = "" +
|
||||
"SELECT json_nid FROM relayapi_queue" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueueTransactionsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
|
||||
const selectQueueEntryCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM relayapi_queue" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
type queueTransactionsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueTransactionStmt *sql.Stmt
|
||||
deleteQueueTransactionsStmt *sql.Stmt
|
||||
selectQueueTransactionsStmt *sql.Stmt
|
||||
selectQueueTransactionsCountStmt *sql.Stmt
|
||||
type relayQueueStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueEntryStmt *sql.Stmt
|
||||
deleteQueueEntriesStmt *sql.Stmt
|
||||
selectQueueEntriesStmt *sql.Stmt
|
||||
selectQueueEntryCountStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
|
||||
s = &queueTransactionsStatements{
|
||||
func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
|
||||
s = &relayQueueStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(queueTransactionsSchema)
|
||||
_, err = s.db.Exec(relayQueueSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil {
|
||||
if s.insertQueueEntryStmt, err = s.db.Prepare(insertQueueEntrySQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil {
|
||||
if s.deleteQueueEntriesStmt, err = s.db.Prepare(deleteQueueEntriesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil {
|
||||
if s.selectQueueEntriesStmt, err = s.db.Prepare(selectQueueEntriesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil {
|
||||
if s.selectQueueEntryCountStmt, err = s.db.Prepare(selectQueueEntryCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) InsertQueueTransaction(
|
||||
func (s *relayQueueStatements) InsertQueueEntry(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
transactionID, // the transaction ID that we initially attempted
|
||||
|
|
@ -106,22 +106,22 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) DeleteQueueTransactions(
|
||||
func (s *relayQueueStatements) DeleteQueueEntries(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) SelectQueueTransactions(
|
||||
func (s *relayQueueStatements) SelectQueueEntries(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -139,11 +139,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
|
||||
func (s *relayQueueStatements) SelectQueueEntryCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
59
relayapi/storage/postgres/storage.go
Normal file
59
relayapi/storage/postgres/storage.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database stores information needed by the relayapi
|
||||
type Database struct {
|
||||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queue, err := NewPostgresRelayQueueTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
RelayQueue: queue,
|
||||
RelayQueueJSON: queueJSON,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
142
relayapi/storage/shared/storage.go
Normal file
142
relayapi/storage/shared/storage.go
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
DB *sql.DB
|
||||
IsLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
Cache caching.FederationCache
|
||||
Writer sqlutil.Writer
|
||||
RelayQueue tables.RelayQueue
|
||||
RelayQueueJSON tables.RelayQueueJSON
|
||||
}
|
||||
|
||||
func (d *Database) StoreAsyncTransaction(
|
||||
ctx context.Context, txn gomatrixserverlib.Transaction,
|
||||
) (*shared.Receipt, error) {
|
||||
var err error
|
||||
json, err := json.Marshal(txn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.JSONUnmarshall: %w", err)
|
||||
}
|
||||
|
||||
var nid int64
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(json))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
|
||||
}
|
||||
|
||||
receipt := shared.NewReceipt(nid)
|
||||
return &receipt, nil
|
||||
}
|
||||
|
||||
func (d *Database) AssociateAsyncTransactionWithDestinations(
|
||||
ctx context.Context,
|
||||
destinations map[gomatrixserverlib.UserID]struct{},
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
receipt *shared.Receipt,
|
||||
) error {
|
||||
for destination := range destinations {
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.RelayQueue.InsertQueueEntry(
|
||||
ctx, txn, transactionID, destination.Domain(), receipt.GetNID())
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.insertQueueEntry: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) CleanAsyncTransactions(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
receipts []*shared.Receipt,
|
||||
) error {
|
||||
println(len(receipts))
|
||||
nids := make([]int64, len(receipts))
|
||||
for i, receipt := range receipts {
|
||||
nids[i] = receipt.GetNID()
|
||||
}
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.deleteQueueEntries: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransaction(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) (*gomatrixserverlib.Transaction, *shared.Receipt, error) {
|
||||
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
|
||||
}
|
||||
if len(nids) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
txns := map[int64][]byte{}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
|
||||
}
|
||||
|
||||
transaction := &gomatrixserverlib.Transaction{}
|
||||
err = json.Unmarshal(txns[nids[0]], transaction)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Unmarshall transaction: %w", err)
|
||||
}
|
||||
|
||||
receipt := shared.NewReceipt(nids[0])
|
||||
return transaction, &receipt, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetAsyncTransactionCount(
|
||||
ctx context.Context,
|
||||
userID gomatrixserverlib.UserID,
|
||||
) (int64, error) {
|
||||
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
|
@ -24,53 +24,53 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const transactionJSONSchema = `
|
||||
-- The federationsender_transaction_json table contains event contents that
|
||||
const relayQueueJSONSchema = `
|
||||
-- The relayapi_queue_json table contains event contents that
|
||||
-- we are storing for future forwarding.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
|
||||
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
|
||||
-- The JSON NID. This allows cross-referencing to find the JSON blob.
|
||||
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The JSON body. Text so that we preserve UTF-8.
|
||||
json_body TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
|
||||
ON federationsender_transaction_json (json_nid);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
|
||||
ON relayapi_queue_json (json_nid);
|
||||
`
|
||||
|
||||
const insertTransactionJSONSQL = "" +
|
||||
"INSERT INTO federationsender_transaction_json (json_body)" +
|
||||
const insertQueueJSONSQL = "" +
|
||||
"INSERT INTO relayapi_queue_json (json_body)" +
|
||||
" VALUES ($1)"
|
||||
|
||||
const deleteTransactionJSONSQL = "" +
|
||||
"DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)"
|
||||
const deleteQueueJSONSQL = "" +
|
||||
"DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
|
||||
|
||||
const selectTransactionJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
|
||||
const selectQueueJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM relayapi_queue_json" +
|
||||
" WHERE json_nid IN ($1)"
|
||||
|
||||
type transactionJSONStatements struct {
|
||||
type relayQueueJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertJSONStmt *sql.Stmt
|
||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
|
||||
s = &transactionJSONStatements{
|
||||
func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
|
||||
s = &relayQueueJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(transactionJSONSchema)
|
||||
_, err = db.Exec(relayQueueJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil {
|
||||
if s.insertJSONStmt, err = db.Prepare(insertQueueJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) InsertTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) InsertQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, json string,
|
||||
) (lastid int64, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||
|
|
@ -85,13 +85,13 @@ func (s *transactionJSONStatements) InsertTransactionJSON(
|
|||
return
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) DeleteTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) DeleteQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, nids []int64,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
deleteStmt, err := txn.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err)
|
||||
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
iNIDs := make([]interface{}, len(nids))
|
||||
|
|
@ -104,13 +104,13 @@ func (s *transactionJSONStatements) DeleteTransactionJSON(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *transactionJSONStatements) SelectTransactionJSON(
|
||||
func (s *relayQueueJSONStatements) SelectQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
|
||||
) (map[int64][]byte, error) {
|
||||
selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
|
||||
selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
|
||||
selectStmt, err := txn.Prepare(selectSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err)
|
||||
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
iNIDs := make([]interface{}, len(jsonNIDs))
|
||||
|
|
@ -122,14 +122,14 @@ func (s *transactionJSONStatements) SelectTransactionJSON(
|
|||
stmt := sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := stmt.QueryContext(ctx, iNIDs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err)
|
||||
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
var blob []byte
|
||||
if err = rows.Scan(&nid, &blob); err != nil {
|
||||
return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err)
|
||||
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
|
||||
}
|
||||
blobs[nid] = blob
|
||||
}
|
||||
|
|
@ -25,79 +25,79 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queueTransactionsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
|
||||
const relayQueueSchema = `
|
||||
CREATE TABLE IF NOT EXISTS relayapi_queue (
|
||||
-- The transaction ID that was generated before persisting the event.
|
||||
transaction_id TEXT NOT NULL,
|
||||
-- The domain part of the user ID the m.room.member event is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_queue_transactions_json table.
|
||||
-- The JSON NID from the relayapi_queue_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
|
||||
ON federationsender_queue_transactions (json_nid, server_name);
|
||||
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
|
||||
ON federationsender_queue_transactions (json_nid);
|
||||
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
|
||||
ON federationsender_queue_transactions (server_name);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
|
||||
ON relayapi_queue (json_nid, server_name);
|
||||
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
|
||||
ON relayapi_queue (json_nid);
|
||||
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
|
||||
ON relayapi_queue (server_name);
|
||||
`
|
||||
|
||||
const insertQueueTransactionSQL = "" +
|
||||
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
|
||||
const insertQueueEntrySQL = "" +
|
||||
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueueTransactionsSQL = "" +
|
||||
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)"
|
||||
const deleteQueueEntriesSQL = "" +
|
||||
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
|
||||
|
||||
const selectQueueTransactionsSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_transactions" +
|
||||
const selectQueueEntriesSQL = "" +
|
||||
"SELECT json_nid FROM relayapi_queue" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueueTransactionsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
|
||||
const selectQueueEntryCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM relayapi_queue" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
type queueTransactionsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueTransactionStmt *sql.Stmt
|
||||
selectQueueTransactionsStmt *sql.Stmt
|
||||
selectQueueTransactionsCountStmt *sql.Stmt
|
||||
// deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
type relayQueueStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueEntryStmt *sql.Stmt
|
||||
selectQueueEntriesStmt *sql.Stmt
|
||||
selectQueueEntryCountStmt *sql.Stmt
|
||||
// deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
|
||||
s = &queueTransactionsStatements{
|
||||
func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) {
|
||||
s = &relayQueueStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(queueTransactionsSchema)
|
||||
_, err = db.Exec(relayQueueSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil {
|
||||
if s.insertQueueEntryStmt, err = db.Prepare(insertQueueEntrySQL); err != nil {
|
||||
return
|
||||
}
|
||||
//if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil {
|
||||
//if s.deleteQueueEntriesStmt, err = db.Prepare(deleteQueueEntriesSQL); err != nil {
|
||||
// return
|
||||
//}
|
||||
if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil {
|
||||
if s.selectQueueEntriesStmt, err = db.Prepare(selectQueueEntriesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil {
|
||||
if s.selectQueueEntryCountStmt, err = db.Prepare(selectQueueEntryCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) InsertQueueTransaction(
|
||||
func (s *relayQueueStatements) InsertQueueEntry(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
transactionID, // the transaction ID that we initially attempted
|
||||
|
|
@ -107,15 +107,15 @@ func (s *queueTransactionsStatements) InsertQueueTransaction(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) DeleteQueueTransactions(
|
||||
func (s *relayQueueStatements) DeleteQueueEntries(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
|
||||
deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
|
||||
deleteStmt, err := txn.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err)
|
||||
return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
params := make([]interface{}, len(jsonNIDs)+1)
|
||||
|
|
@ -129,12 +129,12 @@ func (s *queueTransactionsStatements) DeleteQueueTransactions(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) SelectQueueTransactions(
|
||||
func (s *relayQueueStatements) SelectQueueEntries(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -152,11 +152,11 @@ func (s *queueTransactionsStatements) SelectQueueTransactions(
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
|
||||
func (s *relayQueueStatements) SelectQueueEntryCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
53
relayapi/storage/sqlite3/storage.go
Normal file
53
relayapi/storage/sqlite3/storage.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database stores information needed by the federation sender
|
||||
type Database struct {
|
||||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queue, err := NewSQLiteRelayQueueTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
IsLocalServerName: isLocalServerName,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
RelayQueue: queue,
|
||||
RelayQueueJSON: queueJSON,
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
41
relayapi/storage/storage.go
Normal file
41
relayapi/storage/storage.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !wasm
|
||||
// +build !wasm
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
35
relayapi/storage/tables/interface.go
Normal file
35
relayapi/storage/tables/interface.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type RelayQueue interface {
|
||||
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
}
|
||||
|
||||
type RelayQueueJSON interface {
|
||||
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
|
||||
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
|
||||
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
|
||||
}
|
||||
|
|
@ -8,10 +8,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -35,31 +35,31 @@ func mustCreateTransaction() gomatrixserverlib.Transaction {
|
|||
return txn
|
||||
}
|
||||
|
||||
type TransactionJSONDatabase struct {
|
||||
type RelayQueueJSONDatabase struct {
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
Table tables.FederationTransactionJSON
|
||||
Table tables.RelayQueueJSON
|
||||
}
|
||||
|
||||
func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) {
|
||||
func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
var tab tables.FederationTransactionJSON
|
||||
var tab tables.RelayQueueJSON
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresTransactionJSONTable(db)
|
||||
tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
|
||||
assert.NoError(t, err)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSQLiteTransactionJSONTable(db)
|
||||
tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
database = TransactionJSONDatabase{
|
||||
database = RelayQueueJSONDatabase{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
Table: tab,
|
||||
|
|
@ -70,7 +70,7 @@ func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database
|
|||
func TestShoudInsertTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateTransactionJSONTable(t, dbType)
|
||||
db, close := mustCreateQueueJSONTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transaction := mustCreateTransaction()
|
||||
|
|
@ -79,7 +79,7 @@ func TestShoudInsertTransaction(t *testing.T) {
|
|||
t.Fatalf("Invalid transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
_, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx))
|
||||
_, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
|
@ -89,7 +89,7 @@ func TestShoudInsertTransaction(t *testing.T) {
|
|||
func TestShouldRetrieveInsertedTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateTransactionJSONTable(t, dbType)
|
||||
db, close := mustCreateQueueJSONTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transaction := mustCreateTransaction()
|
||||
|
|
@ -98,14 +98,14 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
|
|||
t.Fatalf("Invalid transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
|
||||
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
var storedJSON map[int64][]byte
|
||||
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
|
||||
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
|
||||
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -124,7 +124,7 @@ func TestShouldRetrieveInsertedTransaction(t *testing.T) {
|
|||
func TestShouldDeleteTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateTransactionJSONTable(t, dbType)
|
||||
db, close := mustCreateQueueJSONTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transaction := mustCreateTransaction()
|
||||
|
|
@ -133,14 +133,14 @@ func TestShouldDeleteTransaction(t *testing.T) {
|
|||
t.Fatalf("Invalid transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
|
||||
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
storedJSON := map[int64][]byte{}
|
||||
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
|
||||
err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid})
|
||||
err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -149,7 +149,7 @@ func TestShouldDeleteTransaction(t *testing.T) {
|
|||
|
||||
storedJSON = map[int64][]byte{}
|
||||
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
|
||||
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
|
||||
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -7,41 +7,41 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/relayapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type QueueTransactionsDatabase struct {
|
||||
type RelayQueueDatabase struct {
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
Table tables.FederationQueueTransactions
|
||||
Table tables.RelayQueue
|
||||
}
|
||||
|
||||
func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) {
|
||||
func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
var tab tables.FederationQueueTransactions
|
||||
var tab tables.RelayQueue
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresQueueTransactionsTable(db)
|
||||
tab, err = postgres.NewPostgresRelayQueueTable(db)
|
||||
assert.NoError(t, err)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db)
|
||||
tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
database = QueueTransactionsDatabase{
|
||||
database = RelayQueueDatabase{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
Table: tab,
|
||||
|
|
@ -52,13 +52,13 @@ func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (databas
|
|||
func TestShoudInsertQueueTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateQueueTransactionsTable(t, dbType)
|
||||
db, close := mustCreateQueueTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
serverName := gomatrixserverlib.ServerName("domain")
|
||||
nid := int64(1)
|
||||
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
|
||||
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
|
@ -68,19 +68,19 @@ func TestShoudInsertQueueTransaction(t *testing.T) {
|
|||
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateQueueTransactionsTable(t, dbType)
|
||||
db, close := mustCreateQueueTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
serverName := gomatrixserverlib.ServerName("domain")
|
||||
nid := int64(1)
|
||||
|
||||
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
|
||||
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10)
|
||||
retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction: %s", err.Error())
|
||||
}
|
||||
|
|
@ -93,27 +93,27 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
|
|||
func TestShouldDeleteQueueTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateQueueTransactionsTable(t, dbType)
|
||||
db, close := mustCreateQueueTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
serverName := gomatrixserverlib.ServerName("domain")
|
||||
nid := int64(1)
|
||||
|
||||
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
|
||||
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
|
||||
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
|
||||
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
|
||||
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
|
|
@ -124,7 +124,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
|
|||
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateQueueTransactionsTable(t, dbType)
|
||||
db, close := mustCreateQueueTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
|
|
@ -135,34 +135,34 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
|
|||
nid2 := int64(2)
|
||||
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
|
||||
|
||||
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
|
||||
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid)
|
||||
err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2)
|
||||
err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
|
||||
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
|
||||
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
|
||||
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
|
||||
count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
|
||||
}
|
||||
|
|
@ -62,6 +62,7 @@ type Dendrite struct {
|
|||
RoomServer RoomServer `yaml:"room_server"`
|
||||
SyncAPI SyncAPI `yaml:"sync_api"`
|
||||
UserAPI UserAPI `yaml:"user_api"`
|
||||
RelayAPI RelayAPI `yaml:"relay_api"`
|
||||
|
||||
MSCs MSCs `yaml:"mscs"`
|
||||
|
||||
|
|
|
|||
38
setup/config/config_relayapi.go
Normal file
38
setup/config/config_relayapi.go
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
package config
|
||||
|
||||
type RelayAPI struct {
|
||||
Matrix *Global `yaml:"-"`
|
||||
|
||||
InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"`
|
||||
ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"`
|
||||
|
||||
// The database stores information used by the relay queue to
|
||||
// forward transactions to remote servers.
|
||||
Database DatabaseOptions `yaml:"database,omitempty"`
|
||||
}
|
||||
|
||||
func (c *RelayAPI) Defaults(opts DefaultOpts) {
|
||||
if !opts.Monolithic {
|
||||
c.InternalAPI.Listen = "http://localhost:7775"
|
||||
c.InternalAPI.Connect = "http://localhost:7775"
|
||||
c.ExternalAPI.Listen = "http://[::]:8075"
|
||||
c.Database.Defaults(10)
|
||||
}
|
||||
if opts.Generate {
|
||||
if !opts.Monolithic {
|
||||
c.Database.ConnectionString = "file:relayapi.db"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||
if isMonolith { // polylith required configs below
|
||||
return
|
||||
}
|
||||
if c.Matrix.DatabaseOptions.ConnectionString == "" {
|
||||
checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString))
|
||||
}
|
||||
checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen))
|
||||
checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen))
|
||||
checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect))
|
||||
}
|
||||
|
|
@ -23,6 +23,8 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/transactions"
|
||||
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/mediaapi"
|
||||
"github.com/matrix-org/dendrite/relayapi"
|
||||
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
|
@ -44,6 +46,7 @@ type Monolith struct {
|
|||
RoomserverAPI roomserverAPI.RoomserverInternalAPI
|
||||
UserAPI userapi.UserInternalAPI
|
||||
KeyAPI keyAPI.KeyInternalAPI
|
||||
RelayAPI relayAPI.RelayInternalAPI
|
||||
|
||||
// Optional
|
||||
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
|
||||
|
|
@ -71,4 +74,9 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
|
|||
syncapi.AddPublicRoutes(
|
||||
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
|
||||
)
|
||||
|
||||
// TODO : relayapi.AddPublicRoutes
|
||||
if m.RelayAPI != nil {
|
||||
relayapi.AddPublicRoutes(base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.RelayAPI, m.FederationAPI, m.KeyAPI)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue