Refactor federationapi db to better distinguish p2p specifics

This commit is contained in:
Devon Hudson 2023-01-13 11:35:49 -07:00
parent d926444daf
commit f45d612ebd
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
8 changed files with 178 additions and 83 deletions

View file

@ -738,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
return fmt.Errorf("auth chain response is missing m.room.create event")
}
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
func setDefaultRoomVersionFromJoinEvent(
joinEvent gomatrixserverlib.EventBuilder,
) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
@ -829,7 +831,7 @@ func (r *FederationInternalAPI) P2PQueryRelayServers(
response *api.P2PQueryRelayServersResponse,
) error {
logrus.Infof("Getting relay servers for: %s", request.Server)
relayServers, err := r.db.GetRelayServersForServer(request.Server)
relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server)
if err != nil {
return err
}
@ -838,7 +840,9 @@ func (r *FederationInternalAPI) P2PQueryRelayServers(
return nil
}
func (r *FederationInternalAPI) shouldAttemptDirectFederation(destination gomatrixserverlib.ServerName) bool {
func (r *FederationInternalAPI) shouldAttemptDirectFederation(
destination gomatrixserverlib.ServerName,
) bool {
var shouldRelay bool
stats := r.statistics.ForServer(destination)
if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {

View file

@ -44,11 +44,11 @@ func TestPerformWakeupServers(t *testing.T) {
server := gomatrixserverlib.ServerName("wakeup")
testDB.AddServerToBlacklist(server)
testDB.SetServerAssumedOffline(server)
testDB.SetServerAssumedOffline(context.Background(), server)
blacklisted, err := testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.True(t, blacklisted)
offline, err := testDB.IsServerAssumedOffline(server)
offline, err := testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.True(t, offline)
@ -81,7 +81,7 @@ func TestPerformWakeupServers(t *testing.T) {
blacklisted, err = testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.False(t, blacklisted)
offline, err = testDB.IsServerAssumedOffline(server)
offline, err = testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.False(t, offline)
}
@ -91,7 +91,7 @@ func TestQueryRelayServers(t *testing.T) {
server := gomatrixserverlib.ServerName("wakeup")
relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"}
err := testDB.AddRelayServersForServer(server, relayServers)
err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers)
assert.NoError(t, err)
cfg := config.FederationAPI{
@ -158,8 +158,8 @@ func TestPerformDirectoryLookupRelaying(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
testDB.SetServerAssumedOffline(server)
testDB.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{"relay"})
testDB.SetServerAssumedOffline(context.Background(), server)
testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"})
cfg := config.FederationAPI{
Matrix: &config.Global{

View file

@ -858,7 +858,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(destination); val {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
@ -891,7 +891,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(destination); val {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
@ -938,7 +938,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) {
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination)
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}
@ -977,6 +977,6 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) {
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination)
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}

View file

@ -1,6 +1,7 @@
package statistics
import (
"context"
"math"
"math/rand"
"sync"
@ -72,14 +73,14 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
} else {
server.blacklisted.Store(blacklisted)
}
assumedOffline, err := s.DB.IsServerAssumedOffline(serverName)
assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName)
} else {
server.assumedOffline.Store(assumedOffline)
}
knownRelayServers, err := s.DB.GetRelayServersForServer(serverName)
knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName)
} else {
@ -186,7 +187,7 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
if backoffCount >= s.statistics.FailuresUntilAssumedOffline {
s.assumedOffline.CompareAndSwap(false, true)
if s.statistics.DB != nil {
if err := s.statistics.DB.SetServerAssumedOffline(s.serverName); err != nil {
if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName)
}
}
@ -291,7 +292,7 @@ func (s *ServerStatistics) removeBlacklist() bool {
// removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() {
_ = s.statistics.DB.RemoveServerAssumedOffline(s.serverName)
_ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName)
}
s.assumedOffline.Store(false)
}
@ -321,7 +322,7 @@ func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.Serv
uniqueList = append(uniqueList, srv)
}
err := s.statistics.DB.AddRelayServersForServer(s.serverName, uniqueList)
err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList)
if err != nil {
logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList)
return

View file

@ -54,16 +54,15 @@ type Database interface {
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
// these don't have contexts passed in as we want things to happen regardless of the request context
SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error
RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error
RemoveAllServersAssumedOffline() error
IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error)
SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
RemoveAllServersAssumedOffline(ctx context.Context) error
IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error)
AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error
P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error

View file

@ -100,11 +100,18 @@ func (d *Database) GetJoinedHosts(
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetJoinedHostsForRooms(
ctx context.Context,
roomIDs []string,
excludeSelf,
excludeBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil {
return nil, err
@ -140,13 +147,17 @@ func (d *Database) StoreJSON(
return &newReceipt, nil
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) AddServerToBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) RemoveServerFromBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
})
@ -158,95 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error {
})
}
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
func (d *Database) IsServerBlacklisted(
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
}
func (d *Database) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error {
func (d *Database) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.InsertAssumedOffline(context.TODO(), txn, serverName)
return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error {
func (d *Database) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAssumedOffline(context.TODO(), txn, serverName)
return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveAllServersAssumedOffline() error {
func (d *Database) RemoveAllServersAssumedOffline(
ctx context.Context,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAllAssumedOffline(context.TODO(), txn)
return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn)
})
}
func (d *Database) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) {
return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName)
func (d *Database) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName)
}
func (d *Database) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
func (d *Database) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.InsertRelayServers(context.TODO(), txn, serverName, relayServers)
return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) {
return d.FederationRelayServers.SelectRelayServers(context.TODO(), nil, serverName)
func (d *Database) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName)
}
func (d *Database) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
func (d *Database) P2PRemoveRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteRelayServers(context.TODO(), txn, serverName, relayServers)
return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error {
func (d *Database) P2PRemoveAllRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteAllRelayServers(context.TODO(), txn, serverName)
return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName)
})
}
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) AddOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
func (d *Database) GetOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID,
peekID string,
) (*types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
func (d *Database) GetOutboundPeeks(
ctx context.Context,
roomID string,
) ([]types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
}
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) AddInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
func (d *Database) GetInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
) (*types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
func (d *Database) GetInboundPeeks(
ctx context.Context,
roomID string,
) ([]types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
}
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
func (d *Database) UpdateNotaryKeys(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
serverKeys gomatrixserverlib.ServerKeys,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
validUntil := serverKeys.ValidUntilTS
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
@ -281,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv
}
func (d *Database) GetNotaryKeys(
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
ctx context.Context,
serverName gomatrixserverlib.ServerName,
optKeyIDs []gomatrixserverlib.KeyID,
) (sks []gomatrixserverlib.ServerKeys, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)

View file

@ -254,39 +254,39 @@ func TestServersAssumedOffline(t *testing.T) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
err := db.SetServerAssumedOffline(server1)
err := db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
isOffline, err := db.IsServerAssumedOffline(server1)
isOffline, err := db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
err = db.RemoveServerAssumedOffline(server1)
err = db.RemoveServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
err = db.SetServerAssumedOffline(server1)
err = db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
err = db.SetServerAssumedOffline(server2)
err = db.SetServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(server2)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.True(t, isOffline)
err = db.RemoveAllServersAssumedOffline()
err = db.RemoveAllServersAssumedOffline(context.Background())
assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(server2)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.False(t, isOffline)
})
@ -301,32 +301,32 @@ func TestRelayServersStored(t *testing.T) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
err := db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1})
err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err := db.GetRelayServersForServer(server)
relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
err = db.RemoveRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1})
err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
err = db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
assert.Equal(t, relayServer2, relayServers[1])
err = db.RemoveAllRelayServersForServer(server)
err = db.P2PRemoveAllRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
})

View file

@ -311,6 +311,7 @@ func (d *InMemoryFederationDatabase) IsServerBlacklisted(
}
func (d *InMemoryFederationDatabase) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
@ -321,6 +322,7 @@ func (d *InMemoryFederationDatabase) SetServerAssumedOffline(
}
func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
@ -330,7 +332,9 @@ func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline(
return nil
}
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error {
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine(
ctx context.Context,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
@ -339,6 +343,7 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error {
}
func (d *InMemoryFederationDatabase) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
d.dbMutex.Lock()
@ -352,7 +357,8 @@ func (d *InMemoryFederationDatabase) IsServerAssumedOffline(
return assumedOffline, nil
}
func (d *InMemoryFederationDatabase) GetRelayServersForServer(
func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
@ -366,7 +372,8 @@ func (d *InMemoryFederationDatabase) GetRelayServersForServer(
return knownRelayServers, nil
}
func (d *InMemoryFederationDatabase) AddRelayServersForServer(
func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
@ -420,15 +427,15 @@ func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context,
return nil, nil
}
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline() error {
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error {
return nil
}
func (d *InMemoryFederationDatabase) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
return nil
}
func (d *InMemoryFederationDatabase) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error {
func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error {
return nil
}