From e83940b5cf5eb68cb5c686af5513e474f837439e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 12 Jan 2023 16:11:15 -0700 Subject: [PATCH] Move in memory test databases into test package --- .../internal/federationclient_test.go | 12 +- federationapi/internal/perform_test.go | 10 +- federationapi/queue/destinationqueue.go | 16 +- federationapi/queue/queue.go | 6 +- federationapi/queue/queue_test.go | 2 +- federationapi/statistics/statistics_test.go | 4 +- federationapi/storage/fake_federation_db.go | 337 ------------------ federationapi/storage/interface.go | 16 +- .../storage/shared/receipt/receipt.go | 37 ++ federationapi/storage/shared/storage.go | 28 +- federationapi/storage/shared/storage_edus.go | 29 +- federationapi/storage/shared/storage_pdus.go | 27 +- .../tables/relay_servers_table_test.go | 5 +- relayapi/internal/perform.go | 6 +- relayapi/internal/perform_test.go | 15 +- relayapi/inthttp/client.go | 6 +- relayapi/inthttp/server.go | 18 +- relayapi/routing/relaytxn_test.go | 8 +- relayapi/routing/sendrelay_test.go | 12 +- relayapi/storage/fake_relay_db.go | 109 ------ relayapi/storage/interface.go | 10 +- .../storage/postgres/relay_queue_table.go | 14 +- relayapi/storage/postgres/storage.go | 7 +- relayapi/storage/shared/storage.go | 17 +- relayapi/storage/sqlite3/relay_queue_table.go | 14 +- relayapi/storage/sqlite3/storage.go | 7 +- relayapi/storage/storage.go | 7 +- .../tables/relay_queue_json_table_test.go | 5 +- .../storage/tables/relay_queue_table_test.go | 5 +- 29 files changed, 212 insertions(+), 577 deletions(-) delete mode 100644 federationapi/storage/fake_federation_db.go create mode 100644 federationapi/storage/shared/receipt/receipt.go delete mode 100644 relayapi/storage/fake_relay_db.go diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go index e2570d4bd..49137e2d8 100644 --- a/federationapi/internal/federationclient_test.go +++ b/federationapi/internal/federationclient_test.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/statistics" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -50,7 +50,7 @@ func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverl } func TestFederationClientQueryKeys(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() cfg := config.FederationAPI{ Matrix: &config.Global{ @@ -80,7 +80,7 @@ func TestFederationClientQueryKeys(t *testing.T) { } func TestFederationClientQueryKeysBlacklisted(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() testDB.AddServerToBlacklist("server") cfg := config.FederationAPI{ @@ -111,7 +111,7 @@ func TestFederationClientQueryKeysBlacklisted(t *testing.T) { } func TestFederationClientQueryKeysFailure(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() cfg := config.FederationAPI{ Matrix: &config.Global{ @@ -141,7 +141,7 @@ func TestFederationClientQueryKeysFailure(t *testing.T) { } func TestFederationClientClaimKeys(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() cfg := config.FederationAPI{ Matrix: &config.Global{ @@ -171,7 +171,7 @@ func TestFederationClientClaimKeys(t *testing.T) { } func TestFederationClientClaimKeysBlacklisted(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() testDB.AddServerToBlacklist("server") cfg := config.FederationAPI{ diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index 49e38b3a7..dd68e266e 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/statistics" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -40,7 +40,7 @@ func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixs } func TestPerformWakeupServers(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() server := gomatrixserverlib.ServerName("wakeup") testDB.AddServerToBlacklist(server) @@ -87,7 +87,7 @@ func TestPerformWakeupServers(t *testing.T) { } func TestQueryRelayServers(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() server := gomatrixserverlib.ServerName("wakeup") relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} @@ -124,7 +124,7 @@ func TestQueryRelayServers(t *testing.T) { } func TestPerformDirectoryLookup(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() cfg := config.FederationAPI{ Matrix: &config.Global{ @@ -155,7 +155,7 @@ func TestPerformDirectoryLookup(t *testing.T) { } func TestPerformDirectoryLookupRelaying(t *testing.T) { - testDB := storage.NewFakeFederationDatabase() + testDB := test.NewInMemoryFederationDatabase() server := gomatrixserverlib.ServerName("wakeup") testDB.SetServerAssumedOffline(server) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 19d41fe3d..3d1c406c4 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, nid *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -85,7 +85,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ pdu: event, - receipt: receipt, + receipt: nid, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, nid *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -116,7 +116,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ edu: event, - receipt: receipt, + receipt: nid, }) } else { oq.overflowed.Store(true) @@ -479,7 +479,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -500,8 +500,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index d7744790a..dcb303e79 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,12 +138,12 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt + receipt *receipt.Receipt pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt + receipt *receipt.Receipt edu *gomatrixserverlib.EDU } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 0e573ea54..069f8caca 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -54,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := storage.NewFakeFederationDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 40b80755a..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -106,7 +106,7 @@ func TestBackoff(t *testing.T) { } func TestRelayServersListing(t *testing.T) { - stats := NewStatistics(storage.NewFakeFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{statistics: &stats} server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) relayServers := server.KnownRelayServers() diff --git a/federationapi/storage/fake_federation_db.go b/federationapi/storage/fake_federation_db.go deleted file mode 100644 index 0031ba871..000000000 --- a/federationapi/storage/fake_federation_db.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - "errors" - "sync" - "time" - - "github.com/matrix-org/dendrite/federationapi/storage/shared" - "github.com/matrix-org/gomatrixserverlib" -) - -var nidMutex sync.Mutex -var nid = int64(0) - -type FakeFederationDatabase struct { - Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - assumedOffline map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName -} - -func NewFakeFederationDatabase() *FakeFederationDatabase { - return &FakeFederationDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), - } -} - -func (d *FakeFederationDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *FakeFederationDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *FakeFederationDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *FakeFederationDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *FakeFederationDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *FakeFederationDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *FakeFederationDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *FakeFederationDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *FakeFederationDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - -func (d *FakeFederationDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *FakeFederationDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *FakeFederationDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *FakeFederationDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *FakeFederationDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *FakeFederationDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - -func (d *FakeFederationDatabase) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.assumedOffline[serverName] = struct{}{} - return nil -} - -func (d *FakeFederationDatabase) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.assumedOffline, serverName) - return nil -} - -func (d *FakeFederationDatabase) RemoveAllServersAssumedOffine() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *FakeFederationDatabase) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - assumedOffline := false - if _, ok := d.assumedOffline[serverName]; ok { - assumedOffline = true - } - - return assumedOffline, nil -} - -func (d *FakeFederationDatabase) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - knownRelayServers := []gomatrixserverlib.ServerName{} - if relayServers, ok := d.relayServers[serverName]; ok { - knownRelayServers = relayServers - } - - return knownRelayServers, nil -} - -func (d *FakeFederationDatabase) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if knownRelayServers, ok := d.relayServers[serverName]; ok { - for _, relayServer := range relayServers { - alreadyKnown := false - for _, knownRelayServer := range knownRelayServers { - if relayServer == knownRelayServer { - alreadyKnown = true - } - } - if !alreadyKnown { - d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) - } - } - } else { - d.relayServers[serverName] = relayServers - } - - return nil -} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 2340b4fc7..7cf23273d 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) @@ -34,16 +34,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..33a71a1a8 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,37 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 3668d6e2c..130869340 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -46,26 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) GetNID() int64 { - return r.nid -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -145,7 +126,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -155,9 +136,8 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index be8355f31..ae7e884b8 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + receipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + receipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index da4cb979d..0f5844520 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + receipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + receipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go index 86d1c1e8d..4d29514e8 100644 --- a/federationapi/storage/tables/relay_servers_table_test.go +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -27,7 +27,10 @@ type RelayServersDatabase struct { Table tables.FederationRelayServers } -func mustCreateRelayServersTable(t *testing.T, dbType test.DBType) (database RelayServersDatabase, close func()) { +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go index 631a24454..d5999abea 100644 --- a/relayapi/internal/perform.go +++ b/relayapi/internal/perform.go @@ -17,7 +17,7 @@ package internal import ( "context" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -87,8 +87,8 @@ func (r *RelayInternalAPI) QueryTransactions( request.PreviousEntry.EntryID, request.UserID.Raw(), ) - prevReceipt := shared.NewReceipt(request.PreviousEntry.EntryID) - err := r.db.CleanTransactions(ctx, request.UserID, []*shared.Receipt{&prevReceipt}) + prevReceipt := receipt.NewReceipt(request.PreviousEntry.EntryID) + err := r.db.CleanTransactions(ctx, request.UserID, []*receipt.Receipt{&prevReceipt}) if err != nil { logrus.Errorf("db.CleanTransactions: %s", err.Error()) return err diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go index be11fbcf3..c2f9dd1c7 100644 --- a/relayapi/internal/perform_test.go +++ b/relayapi/internal/perform_test.go @@ -22,8 +22,8 @@ import ( fedAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/api" - "github.com/matrix-org/dendrite/relayapi/storage" "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -35,7 +35,12 @@ type testFedClient struct { queueDepth uint } -func (f *testFedClient) P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) { +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { f.queryCount++ if !f.shouldFail { res = gomatrixserverlib.RespGetRelayTransaction{ @@ -56,7 +61,7 @@ func (f *testFedClient) P2PGetTransactionFromRelay(ctx context.Context, u gomatr } func TestPerformRelayServerSync(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -81,7 +86,7 @@ func TestPerformRelayServerSync(t *testing.T) { } func TestPerformRelayServerSyncFedError(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -106,7 +111,7 @@ func TestPerformRelayServerSyncFedError(t *testing.T) { } func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, diff --git a/relayapi/inthttp/client.go b/relayapi/inthttp/client.go index c86edda8e..6866686cf 100644 --- a/relayapi/inthttp/client.go +++ b/relayapi/inthttp/client.go @@ -33,7 +33,11 @@ const ( // NewRelayAPIClient creates a RelayInternalAPI implemented by talking to a HTTP POST API. // If httpClient is nil an error is returned -func NewRelayAPIClient(relayapiURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.RelayInternalAPI, error) { +func NewRelayAPIClient( + relayapiURL string, + httpClient *http.Client, + cache caching.ServerKeyCache, +) (api.RelayInternalAPI, error) { if httpClient == nil { return nil, errors.New("NewRelayInternalAPIHTTP: httpClient is ") } diff --git a/relayapi/inthttp/server.go b/relayapi/inthttp/server.go index b0fc40eab..0385fa51b 100644 --- a/relayapi/inthttp/server.go +++ b/relayapi/inthttp/server.go @@ -26,16 +26,28 @@ import ( func AddRoutes(intAPI api.RelayInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( RelayAPIPerformRelayServerSyncPath, - httputil.MakeInternalRPCAPI("RelayAPIPerformRelayServerSync", enableMetrics, intAPI.PerformRelayServerSync), + httputil.MakeInternalRPCAPI( + "RelayAPIPerformRelayServerSync", + enableMetrics, + intAPI.PerformRelayServerSync, + ), ) internalAPIMux.Handle( RelayAPIPerformStoreTransactionPath, - httputil.MakeInternalRPCAPI("RelayAPIPerformStoreTransaction", enableMetrics, intAPI.PerformStoreTransaction), + httputil.MakeInternalRPCAPI( + "RelayAPIPerformStoreTransaction", + enableMetrics, + intAPI.PerformStoreTransaction, + ), ) internalAPIMux.Handle( RelayAPIQueryTransactionsPath, - httputil.MakeInternalRPCAPI("RelayAPIQueryTransactions", enableMetrics, intAPI.QueryTransactions), + httputil.MakeInternalRPCAPI( + "RelayAPIQueryTransactions", + enableMetrics, + intAPI.QueryTransactions, + ), ) } diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go index c2a95e517..a0b03a398 100644 --- a/relayapi/routing/relaytxn_test.go +++ b/relayapi/routing/relaytxn_test.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/internal" "github.com/matrix-org/dendrite/relayapi/routing" - "github.com/matrix-org/dendrite/relayapi/storage" "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -41,7 +41,7 @@ func createQuery( } func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -74,7 +74,7 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { } func TestGetReturnsSavedTransaction(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -124,7 +124,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { } func TestGetReturnsMultipleSavedTransactions(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go index b5f2ac6dd..1be1dafa4 100644 --- a/relayapi/routing/sendrelay_test.go +++ b/relayapi/routing/sendrelay_test.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/internal" "github.com/matrix-org/dendrite/relayapi/routing" - "github.com/matrix-org/dendrite/relayapi/storage" "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -58,7 +58,7 @@ func createFederationRequest( } func TestForwardEmptyReturnsOk(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -81,7 +81,7 @@ func TestForwardEmptyReturnsOk(t *testing.T) { } func TestForwardBadJSONReturnsError(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -110,7 +110,7 @@ func TestForwardBadJSONReturnsError(t *testing.T) { } func TestForwardTooManyPDUsReturnsError(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -144,7 +144,7 @@ func TestForwardTooManyPDUsReturnsError(t *testing.T) { } func TestForwardTooManyEDUsReturnsError(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, @@ -178,7 +178,7 @@ func TestForwardTooManyEDUsReturnsError(t *testing.T) { } func TestUniqueTransactionStoredInDatabase(t *testing.T) { - testDB := storage.NewFakeRelayDatabase() + testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ Writer: sqlutil.NewDummyWriter(), RelayQueue: testDB, diff --git a/relayapi/storage/fake_relay_db.go b/relayapi/storage/fake_relay_db.go deleted file mode 100644 index 4c011e0bc..000000000 --- a/relayapi/storage/fake_relay_db.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "database/sql" - "encoding/json" - "sync" - - "github.com/matrix-org/gomatrixserverlib" -) - -type testDatabase struct { - nid int64 - nidMutex sync.Mutex - transactions map[int64]json.RawMessage - associations map[gomatrixserverlib.ServerName][]int64 -} - -func NewFakeRelayDatabase() *testDatabase { - return &testDatabase{ - nid: 1, - nidMutex: sync.Mutex{}, - transactions: make(map[int64]json.RawMessage), - associations: make(map[gomatrixserverlib.ServerName][]int64), - } -} - -func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error { - if _, ok := d.associations[serverName]; !ok { - d.associations[serverName] = []int64{} - } - d.associations[serverName] = append(d.associations[serverName], nid) - return nil -} - -func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error { - for _, nid := range jsonNIDs { - for index, associatedNID := range d.associations[serverName] { - if associatedNID == nid { - d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) - } - } - } - - return nil -} - -func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) { - results := []int64{} - resultCount := limit - if limit > len(d.associations[serverName]) { - resultCount = len(d.associations[serverName]) - } - if resultCount > 0 { - for i := 0; i < resultCount; i++ { - results = append(results, d.associations[serverName][i]) - } - } - - return results, nil -} - -func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) { - return int64(len(d.associations[serverName])), nil -} - -func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) { - d.nidMutex.Lock() - defer d.nidMutex.Unlock() - - nid := d.nid - d.transactions[nid] = []byte(json) - d.nid++ - - return nid, nil -} - -func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error { - for _, nid := range nids { - delete(d.transactions, nid) - } - - return nil -} - -func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) { - result := make(map[int64][]byte) - for _, nid := range jsonNIDs { - if transaction, ok := d.transactions[nid]; ok { - result[nid] = transaction - } - } - - return result, nil -} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go index 3fe57b9f6..d39b89aae 100644 --- a/relayapi/storage/interface.go +++ b/relayapi/storage/interface.go @@ -17,14 +17,14 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*shared.Receipt, error) - AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *shared.Receipt) error - CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*shared.Receipt) error - GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *shared.Receipt, error) + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *receipt.Receipt) error + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) } diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go index 929fad22b..a8ba72902 100644 --- a/relayapi/storage/postgres/relay_queue_table.go +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -66,7 +66,9 @@ type relayQueueStatements struct { selectQueueEntryCountStmt *sql.Stmt } -func NewPostgresRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) { +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { s = &relayQueueStatements{ db: db, } @@ -101,7 +103,8 @@ func (s *relayQueueStatements) InsertQueueEntry( } func (s *relayQueueStatements) DeleteQueueEntries( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { @@ -111,7 +114,8 @@ func (s *relayQueueStatements) DeleteQueueEntries( } func (s *relayQueueStatements) SelectQueueEntries( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { @@ -134,7 +138,9 @@ func (s *relayQueueStatements) SelectQueueEntries( } func (s *relayQueueStatements) SelectQueueEntryCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go index 3902cc8ab..1042beba7 100644 --- a/relayapi/storage/postgres/storage.go +++ b/relayapi/storage/postgres/storage.go @@ -33,7 +33,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go index 3d69a91ba..aef24f47e 100644 --- a/relayapi/storage/shared/storage.go +++ b/relayapi/storage/shared/storage.go @@ -20,7 +20,7 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/tables" @@ -37,8 +37,9 @@ type Database struct { } func (d *Database) StoreTransaction( - ctx context.Context, transaction gomatrixserverlib.Transaction, -) (*shared.Receipt, error) { + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { var err error json, err := json.Marshal(transaction) if err != nil { @@ -54,7 +55,7 @@ func (d *Database) StoreTransaction( return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - receipt := shared.NewReceipt(nid) + receipt := receipt.NewReceipt(nid) return &receipt, nil } @@ -62,7 +63,7 @@ func (d *Database) AssociateTransactionWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, - receipt *shared.Receipt, + receipt *receipt.Receipt, ) error { for destination := range destinations { err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -81,7 +82,7 @@ func (d *Database) AssociateTransactionWithDestinations( func (d *Database) CleanTransactions( ctx context.Context, userID gomatrixserverlib.UserID, - receipts []*shared.Receipt, + receipts []*receipt.Receipt, ) error { nids := make([]int64, len(receipts)) for i, receipt := range receipts { @@ -110,7 +111,7 @@ func (d *Database) CleanTransactions( func (d *Database) GetTransaction( ctx context.Context, userID gomatrixserverlib.UserID, -) (*gomatrixserverlib.Transaction, *shared.Receipt, error) { +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), 1) if err != nil { return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) @@ -134,7 +135,7 @@ func (d *Database) GetTransaction( return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) } - receipt := shared.NewReceipt(nids[0]) + receipt := receipt.NewReceipt(nids[0]) return transaction, &receipt, nil } diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go index 778e15c13..72a2d64c3 100644 --- a/relayapi/storage/sqlite3/relay_queue_table.go +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -67,7 +67,9 @@ type relayQueueStatements struct { // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic } -func NewSQLiteRelayQueueTable(db *sql.DB) (s *relayQueueStatements, err error) { +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { s = &relayQueueStatements{ db: db, } @@ -101,7 +103,8 @@ func (s *relayQueueStatements) InsertQueueEntry( } func (s *relayQueueStatements) DeleteQueueEntries( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { @@ -123,7 +126,8 @@ func (s *relayQueueStatements) DeleteQueueEntries( } func (s *relayQueueStatements) SelectQueueEntries( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { @@ -146,7 +150,9 @@ func (s *relayQueueStatements) SelectQueueEntries( } func (s *relayQueueStatements) SelectQueueEntryCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go index dbc698c39..3ed4ab046 100644 --- a/relayapi/storage/sqlite3/storage.go +++ b/relayapi/storage/sqlite3/storage.go @@ -33,7 +33,12 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go index e4cefc1fd..16ecbcfb7 100644 --- a/relayapi/storage/storage.go +++ b/relayapi/storage/storage.go @@ -29,7 +29,12 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go index 3bab1ac30..efa3363e5 100644 --- a/relayapi/storage/tables/relay_queue_json_table_test.go +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -50,7 +50,10 @@ type RelayQueueJSONDatabase struct { Table tables.RelayQueueJSON } -func mustCreateQueueJSONTable(t *testing.T, dbType test.DBType) (database RelayQueueJSONDatabase, close func()) { +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go index 52ea7a000..dc45c02af 100644 --- a/relayapi/storage/tables/relay_queue_table_test.go +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -37,7 +37,10 @@ type RelayQueueDatabase struct { Table tables.RelayQueue } -func mustCreateQueueTable(t *testing.T, dbType test.DBType) (database RelayQueueDatabase, close func()) { +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{