From f45d612ebd498b7d9906a752eba5fe39e9719e2e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 13 Jan 2023 11:35:49 -0700 Subject: [PATCH] Refactor federationapi db to better distinguish p2p specifics --- federationapi/internal/perform.go | 10 +- federationapi/internal/perform_test.go | 12 +- federationapi/queue/queue_test.go | 8 +- federationapi/statistics/statistics.go | 11 +- federationapi/storage/interface.go | 17 ++- federationapi/storage/shared/storage.go | 146 +++++++++++++++++++----- federationapi/storage/storage_test.go | 38 +++--- test/memory_federation_db.go | 19 ++- 8 files changed, 178 insertions(+), 83 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index a22945885..552942f28 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -738,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -829,7 +831,7 @@ func (r *FederationInternalAPI) P2PQueryRelayServers( response *api.P2PQueryRelayServersResponse, ) error { logrus.Infof("Getting relay servers for: %s", request.Server) - relayServers, err := r.db.GetRelayServersForServer(request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) if err != nil { return err } @@ -838,7 +840,9 @@ func (r *FederationInternalAPI) P2PQueryRelayServers( return nil } -func (r *FederationInternalAPI) shouldAttemptDirectFederation(destination gomatrixserverlib.ServerName) bool { +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { var shouldRelay bool stats := r.statistics.ForServer(destination) if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index dd68e266e..e8e0d00a3 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -44,11 +44,11 @@ func TestPerformWakeupServers(t *testing.T) { server := gomatrixserverlib.ServerName("wakeup") testDB.AddServerToBlacklist(server) - testDB.SetServerAssumedOffline(server) + testDB.SetServerAssumedOffline(context.Background(), server) blacklisted, err := testDB.IsServerBlacklisted(server) assert.NoError(t, err) assert.True(t, blacklisted) - offline, err := testDB.IsServerAssumedOffline(server) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) assert.NoError(t, err) assert.True(t, offline) @@ -81,7 +81,7 @@ func TestPerformWakeupServers(t *testing.T) { blacklisted, err = testDB.IsServerBlacklisted(server) assert.NoError(t, err) assert.False(t, blacklisted) - offline, err = testDB.IsServerAssumedOffline(server) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) assert.NoError(t, err) assert.False(t, offline) } @@ -91,7 +91,7 @@ func TestQueryRelayServers(t *testing.T) { server := gomatrixserverlib.ServerName("wakeup") relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} - err := testDB.AddRelayServersForServer(server, relayServers) + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) assert.NoError(t, err) cfg := config.FederationAPI{ @@ -158,8 +158,8 @@ func TestPerformDirectoryLookupRelaying(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() server := gomatrixserverlib.ServerName("wakeup") - testDB.SetServerAssumedOffline(server) - testDB.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{"relay"}) + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) cfg := config.FederationAPI{ Matrix: &config.Global{ diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 069f8caca..36e2ccbc2 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -858,7 +858,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) if len(data) == 1 { - if val, _ := db.IsServerAssumedOffline(destination); val { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { return poll.Success() } return poll.Continue("waiting for server to be assumed offline") @@ -891,7 +891,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) if len(data) == 1 { - if val, _ := db.IsServerAssumedOffline(destination); val { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { return poll.Success() } return poll.Continue("waiting for server to be assumed offline") @@ -938,7 +938,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { } poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) - assumedOffline, _ := db.IsServerAssumedOffline(destination) + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) assert.Equal(t, true, assumedOffline) } @@ -977,6 +977,6 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { } poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) - assumedOffline, _ := db.IsServerAssumedOffline(destination) + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) assert.Equal(t, true, assumedOffline) } diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 430819642..866c09336 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -72,14 +73,14 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } - assumedOffline, err := s.DB.IsServerAssumedOffline(serverName) + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) if err != nil { logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) } else { server.assumedOffline.Store(assumedOffline) } - knownRelayServers, err := s.DB.GetRelayServersForServer(serverName) + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) if err != nil { logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) } else { @@ -186,7 +187,7 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { if backoffCount >= s.statistics.FailuresUntilAssumedOffline { s.assumedOffline.CompareAndSwap(false, true) if s.statistics.DB != nil { - if err := s.statistics.DB.SetServerAssumedOffline(s.serverName); err != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) } } @@ -291,7 +292,7 @@ func (s *ServerStatistics) removeBlacklist() bool { // removeAssumedOffline removes the assumed offline status from the server. func (s *ServerStatistics) removeAssumedOffline() { if s.AssumedOffline() { - _ = s.statistics.DB.RemoveServerAssumedOffline(s.serverName) + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) } s.assumedOffline.Store(false) } @@ -321,7 +322,7 @@ func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.Serv uniqueList = append(uniqueList, srv) } - err := s.statistics.DB.AddRelayServersForServer(s.serverName, uniqueList) + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) if err != nil { logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) return diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 7cf23273d..3248fead6 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -54,16 +54,15 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) - // these don't have contexts passed in as we want things to happen regardless of the request context - SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error - RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error - RemoveAllServersAssumedOffline() error - IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + RemoveAllServersAssumedOffline(ctx context.Context) error + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) - AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) - RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 130869340..2958b79ad 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -100,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -140,13 +147,17 @@ func (d *Database) StoreJSON( return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -158,95 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationAssumedOffline.InsertAssumedOffline(context.TODO(), txn, serverName) + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) }) } -func (d *Database) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationAssumedOffline.DeleteAssumedOffline(context.TODO(), txn, serverName) + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) }) } -func (d *Database) RemoveAllServersAssumedOffline() error { +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationAssumedOffline.DeleteAllAssumedOffline(context.TODO(), txn) + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) }) } -func (d *Database) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) { - return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName) +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) } -func (d *Database) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationRelayServers.InsertRelayServers(context.TODO(), txn, serverName, relayServers) + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) }) } -func (d *Database) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { - return d.FederationRelayServers.SelectRelayServers(context.TODO(), nil, serverName) +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) } -func (d *Database) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationRelayServers.DeleteRelayServers(context.TODO(), txn, serverName, relayServers) + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) }) } -func (d *Database) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error { +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationRelayServers.DeleteAllRelayServers(context.TODO(), txn, serverName) + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) }) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -281,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f05055227..44a5f2028 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -254,39 +254,39 @@ func TestServersAssumedOffline(t *testing.T) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() - err := db.SetServerAssumedOffline(server1) + err := db.SetServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) - isOffline, err := db.IsServerAssumedOffline(server1) + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.True(t, isOffline) - err = db.RemoveServerAssumedOffline(server1) + err = db.RemoveServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) - isOffline, err = db.IsServerAssumedOffline(server1) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.False(t, isOffline) - err = db.SetServerAssumedOffline(server1) + err = db.SetServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) - err = db.SetServerAssumedOffline(server2) + err = db.SetServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) - isOffline, err = db.IsServerAssumedOffline(server1) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.True(t, isOffline) - isOffline, err = db.IsServerAssumedOffline(server2) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) assert.True(t, isOffline) - err = db.RemoveAllServersAssumedOffline() + err = db.RemoveAllServersAssumedOffline(context.Background()) assert.Nil(t, err) - isOffline, err = db.IsServerAssumedOffline(server1) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.False(t, isOffline) - isOffline, err = db.IsServerAssumedOffline(server2) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) assert.False(t, isOffline) }) @@ -301,32 +301,32 @@ func TestRelayServersStored(t *testing.T) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() - err := db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1}) + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) assert.Nil(t, err) - relayServers, err := db.GetRelayServersForServer(server) + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) - err = db.RemoveRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1}) + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) assert.Nil(t, err) - relayServers, err = db.GetRelayServersForServer(server) + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) - err = db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) assert.Nil(t, err) - relayServers, err = db.GetRelayServersForServer(server) + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) assert.Equal(t, relayServer2, relayServers[1]) - err = db.RemoveAllRelayServersForServer(server) + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) assert.Nil(t, err) - relayServers, err = db.GetRelayServersForServer(server) + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) }) diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index 4bfb1b27c..74eb2e7f7 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -311,6 +311,7 @@ func (d *InMemoryFederationDatabase) IsServerBlacklisted( } func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, serverName gomatrixserverlib.ServerName, ) error { d.dbMutex.Lock() @@ -321,6 +322,7 @@ func (d *InMemoryFederationDatabase) SetServerAssumedOffline( } func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, serverName gomatrixserverlib.ServerName, ) error { d.dbMutex.Lock() @@ -330,7 +332,9 @@ func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( return nil } -func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error { +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -339,6 +343,7 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error { } func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, serverName gomatrixserverlib.ServerName, ) (bool, error) { d.dbMutex.Lock() @@ -352,7 +357,8 @@ func (d *InMemoryFederationDatabase) IsServerAssumedOffline( return assumedOffline, nil } -func (d *InMemoryFederationDatabase) GetRelayServersForServer( +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, serverName gomatrixserverlib.ServerName, ) ([]gomatrixserverlib.ServerName, error) { d.dbMutex.Lock() @@ -366,7 +372,8 @@ func (d *InMemoryFederationDatabase) GetRelayServersForServer( return knownRelayServers, nil } -func (d *InMemoryFederationDatabase) AddRelayServersForServer( +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName, ) error { @@ -420,15 +427,15 @@ func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, return nil, nil } -func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline() error { +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { return nil } -func (d *InMemoryFederationDatabase) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { return nil } -func (d *InMemoryFederationDatabase) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error { +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { return nil }