Refactor federationapi db to better distinguish p2p specifics

This commit is contained in:
Devon Hudson 2023-01-13 11:35:49 -07:00
parent d926444daf
commit f45d612ebd
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
8 changed files with 178 additions and 83 deletions

View file

@ -738,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
return fmt.Errorf("auth chain response is missing m.room.create event") return fmt.Errorf("auth chain response is missing m.room.create event")
} }
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { func setDefaultRoomVersionFromJoinEvent(
joinEvent gomatrixserverlib.EventBuilder,
) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+ // if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for: // we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version" // "Outbound federation rejects m.room.create events with an unknown room version"
@ -829,7 +831,7 @@ func (r *FederationInternalAPI) P2PQueryRelayServers(
response *api.P2PQueryRelayServersResponse, response *api.P2PQueryRelayServersResponse,
) error { ) error {
logrus.Infof("Getting relay servers for: %s", request.Server) logrus.Infof("Getting relay servers for: %s", request.Server)
relayServers, err := r.db.GetRelayServersForServer(request.Server) relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server)
if err != nil { if err != nil {
return err return err
} }
@ -838,7 +840,9 @@ func (r *FederationInternalAPI) P2PQueryRelayServers(
return nil return nil
} }
func (r *FederationInternalAPI) shouldAttemptDirectFederation(destination gomatrixserverlib.ServerName) bool { func (r *FederationInternalAPI) shouldAttemptDirectFederation(
destination gomatrixserverlib.ServerName,
) bool {
var shouldRelay bool var shouldRelay bool
stats := r.statistics.ForServer(destination) stats := r.statistics.ForServer(destination)
if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {

View file

@ -44,11 +44,11 @@ func TestPerformWakeupServers(t *testing.T) {
server := gomatrixserverlib.ServerName("wakeup") server := gomatrixserverlib.ServerName("wakeup")
testDB.AddServerToBlacklist(server) testDB.AddServerToBlacklist(server)
testDB.SetServerAssumedOffline(server) testDB.SetServerAssumedOffline(context.Background(), server)
blacklisted, err := testDB.IsServerBlacklisted(server) blacklisted, err := testDB.IsServerBlacklisted(server)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, blacklisted) assert.True(t, blacklisted)
offline, err := testDB.IsServerAssumedOffline(server) offline, err := testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, offline) assert.True(t, offline)
@ -81,7 +81,7 @@ func TestPerformWakeupServers(t *testing.T) {
blacklisted, err = testDB.IsServerBlacklisted(server) blacklisted, err = testDB.IsServerBlacklisted(server)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, blacklisted) assert.False(t, blacklisted)
offline, err = testDB.IsServerAssumedOffline(server) offline, err = testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, offline) assert.False(t, offline)
} }
@ -91,7 +91,7 @@ func TestQueryRelayServers(t *testing.T) {
server := gomatrixserverlib.ServerName("wakeup") server := gomatrixserverlib.ServerName("wakeup")
relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"}
err := testDB.AddRelayServersForServer(server, relayServers) err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers)
assert.NoError(t, err) assert.NoError(t, err)
cfg := config.FederationAPI{ cfg := config.FederationAPI{
@ -158,8 +158,8 @@ func TestPerformDirectoryLookupRelaying(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase() testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup") server := gomatrixserverlib.ServerName("wakeup")
testDB.SetServerAssumedOffline(server) testDB.SetServerAssumedOffline(context.Background(), server)
testDB.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{"relay"}) testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"})
cfg := config.FederationAPI{ cfg := config.FederationAPI{
Matrix: &config.Global{ Matrix: &config.Global{

View file

@ -858,7 +858,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr) assert.NoError(t, dbErr)
if len(data) == 1 { if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(destination); val { if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success() return poll.Success()
} }
return poll.Continue("waiting for server to be assumed offline") return poll.Continue("waiting for server to be assumed offline")
@ -891,7 +891,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr) assert.NoError(t, dbErr)
if len(data) == 1 { if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(destination); val { if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success() return poll.Success()
} }
return poll.Continue("waiting for server to be assumed offline") return poll.Continue("waiting for server to be assumed offline")
@ -938,7 +938,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) {
} }
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination) assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline) assert.Equal(t, true, assumedOffline)
} }
@ -977,6 +977,6 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) {
} }
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(destination) assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline) assert.Equal(t, true, assumedOffline)
} }

View file

@ -1,6 +1,7 @@
package statistics package statistics
import ( import (
"context"
"math" "math"
"math/rand" "math/rand"
"sync" "sync"
@ -72,14 +73,14 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
} else { } else {
server.blacklisted.Store(blacklisted) server.blacklisted.Store(blacklisted)
} }
assumedOffline, err := s.DB.IsServerAssumedOffline(serverName) assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName)
} else { } else {
server.assumedOffline.Store(assumedOffline) server.assumedOffline.Store(assumedOffline)
} }
knownRelayServers, err := s.DB.GetRelayServersForServer(serverName) knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName)
} else { } else {
@ -186,7 +187,7 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
if backoffCount >= s.statistics.FailuresUntilAssumedOffline { if backoffCount >= s.statistics.FailuresUntilAssumedOffline {
s.assumedOffline.CompareAndSwap(false, true) s.assumedOffline.CompareAndSwap(false, true)
if s.statistics.DB != nil { if s.statistics.DB != nil {
if err := s.statistics.DB.SetServerAssumedOffline(s.serverName); err != nil { if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName)
} }
} }
@ -291,7 +292,7 @@ func (s *ServerStatistics) removeBlacklist() bool {
// removeAssumedOffline removes the assumed offline status from the server. // removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() { func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() { if s.AssumedOffline() {
_ = s.statistics.DB.RemoveServerAssumedOffline(s.serverName) _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName)
} }
s.assumedOffline.Store(false) s.assumedOffline.Store(false)
} }
@ -321,7 +322,7 @@ func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.Serv
uniqueList = append(uniqueList, srv) uniqueList = append(uniqueList, srv)
} }
err := s.statistics.DB.AddRelayServersForServer(s.serverName, uniqueList) err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList)
return return

View file

@ -54,16 +54,15 @@ type Database interface {
RemoveAllServersFromBlacklist() error RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
// these don't have contexts passed in as we want things to happen regardless of the request context SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error RemoveAllServersAssumedOffline(ctx context.Context) error
RemoveAllServersAssumedOffline() error IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error)
IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error)
AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error

View file

@ -100,11 +100,18 @@ func (d *Database) GetJoinedHosts(
// GetAllJoinedHosts returns the currently joined hosts for // GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender. // all rooms known to the federation sender.
// Returns an error if something goes wrong. // Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(
ctx context.Context,
roomIDs []string,
excludeSelf,
excludeBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil { if err != nil {
return nil, err return nil, err
@ -140,13 +147,17 @@ func (d *Database) StoreJSON(
return &newReceipt, nil return &newReceipt, nil
} }
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) AddServerToBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
}) })
} }
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) RemoveServerFromBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
}) })
@ -158,95 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error {
}) })
} }
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) IsServerBlacklisted(
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
} }
func (d *Database) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { func (d *Database) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.InsertAssumedOffline(context.TODO(), txn, serverName) return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName)
}) })
} }
func (d *Database) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { func (d *Database) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAssumedOffline(context.TODO(), txn, serverName) return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName)
}) })
} }
func (d *Database) RemoveAllServersAssumedOffline() error { func (d *Database) RemoveAllServersAssumedOffline(
ctx context.Context,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAllAssumedOffline(context.TODO(), txn) return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn)
}) })
} }
func (d *Database) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) IsServerAssumedOffline(
return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName) ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName)
} }
func (d *Database) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { func (d *Database) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.InsertRelayServers(context.TODO(), txn, serverName, relayServers) return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers)
}) })
} }
func (d *Database) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { func (d *Database) P2PGetRelayServersForServer(
return d.FederationRelayServers.SelectRelayServers(context.TODO(), nil, serverName) ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName)
} }
func (d *Database) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { func (d *Database) P2PRemoveRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteRelayServers(context.TODO(), txn, serverName, relayServers) return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers)
}) })
} }
func (d *Database) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error { func (d *Database) P2PRemoveAllRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteAllRelayServers(context.TODO(), txn, serverName) return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName)
}) })
} }
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { func (d *Database) AddOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
}) })
} }
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { func (d *Database) RenewOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
}) })
} }
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { func (d *Database) GetOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID,
peekID string,
) (*types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
} }
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { func (d *Database) GetOutboundPeeks(
ctx context.Context,
roomID string,
) ([]types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
} }
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { func (d *Database) AddInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
}) })
} }
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { func (d *Database) RenewInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
}) })
} }
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { func (d *Database) GetInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
) (*types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
} }
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { func (d *Database) GetInboundPeeks(
ctx context.Context,
roomID string,
) ([]types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
} }
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { func (d *Database) UpdateNotaryKeys(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
serverKeys gomatrixserverlib.ServerKeys,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
validUntil := serverKeys.ValidUntilTS validUntil := serverKeys.ValidUntilTS
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
@ -281,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv
} }
func (d *Database) GetNotaryKeys( func (d *Database) GetNotaryKeys(
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, ctx context.Context,
serverName gomatrixserverlib.ServerName,
optKeyIDs []gomatrixserverlib.KeyID,
) (sks []gomatrixserverlib.ServerKeys, err error) { ) (sks []gomatrixserverlib.ServerKeys, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)

View file

@ -254,39 +254,39 @@ func TestServersAssumedOffline(t *testing.T) {
db, closeDB := mustCreateFederationDatabase(t, dbType) db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB() defer closeDB()
err := db.SetServerAssumedOffline(server1) err := db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
isOffline, err := db.IsServerAssumedOffline(server1) isOffline, err := db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, isOffline) assert.True(t, isOffline)
err = db.RemoveServerAssumedOffline(server1) err = db.RemoveServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1) isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, isOffline) assert.False(t, isOffline)
err = db.SetServerAssumedOffline(server1) err = db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
err = db.SetServerAssumedOffline(server2) err = db.SetServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err) assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1) isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, isOffline) assert.True(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(server2) isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, isOffline) assert.True(t, isOffline)
err = db.RemoveAllServersAssumedOffline() err = db.RemoveAllServersAssumedOffline(context.Background())
assert.Nil(t, err) assert.Nil(t, err)
isOffline, err = db.IsServerAssumedOffline(server1) isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, isOffline) assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(server2) isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, isOffline) assert.False(t, isOffline)
}) })
@ -301,32 +301,32 @@ func TestRelayServersStored(t *testing.T) {
db, closeDB := mustCreateFederationDatabase(t, dbType) db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB() defer closeDB()
err := db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1}) err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err) assert.Nil(t, err)
relayServers, err := db.GetRelayServersForServer(server) relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0]) assert.Equal(t, relayServer1, relayServers[0])
err = db.RemoveRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1}) err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err) assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err) assert.Nil(t, err)
assert.Zero(t, len(relayServers)) assert.Zero(t, len(relayServers))
err = db.AddRelayServersForServer(server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
assert.Nil(t, err) assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0]) assert.Equal(t, relayServer1, relayServers[0])
assert.Equal(t, relayServer2, relayServers[1]) assert.Equal(t, relayServer2, relayServers[1])
err = db.RemoveAllRelayServersForServer(server) err = db.P2PRemoveAllRelayServersForServer(context.Background(), server)
assert.Nil(t, err) assert.Nil(t, err)
relayServers, err = db.GetRelayServersForServer(server) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err) assert.Nil(t, err)
assert.Zero(t, len(relayServers)) assert.Zero(t, len(relayServers))
}) })

View file

@ -311,6 +311,7 @@ func (d *InMemoryFederationDatabase) IsServerBlacklisted(
} }
func (d *InMemoryFederationDatabase) SetServerAssumedOffline( func (d *InMemoryFederationDatabase) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
d.dbMutex.Lock() d.dbMutex.Lock()
@ -321,6 +322,7 @@ func (d *InMemoryFederationDatabase) SetServerAssumedOffline(
} }
func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
d.dbMutex.Lock() d.dbMutex.Lock()
@ -330,7 +332,9 @@ func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline(
return nil return nil
} }
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error { func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine(
ctx context.Context,
) error {
d.dbMutex.Lock() d.dbMutex.Lock()
defer d.dbMutex.Unlock() defer d.dbMutex.Unlock()
@ -339,6 +343,7 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine() error {
} }
func (d *InMemoryFederationDatabase) IsServerAssumedOffline( func (d *InMemoryFederationDatabase) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) (bool, error) { ) (bool, error) {
d.dbMutex.Lock() d.dbMutex.Lock()
@ -352,7 +357,8 @@ func (d *InMemoryFederationDatabase) IsServerAssumedOffline(
return assumedOffline, nil return assumedOffline, nil
} }
func (d *InMemoryFederationDatabase) GetRelayServersForServer( func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock() d.dbMutex.Lock()
@ -366,7 +372,8 @@ func (d *InMemoryFederationDatabase) GetRelayServersForServer(
return knownRelayServers, nil return knownRelayServers, nil
} }
func (d *InMemoryFederationDatabase) AddRelayServersForServer( func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName,
) error { ) error {
@ -420,15 +427,15 @@ func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context,
return nil, nil return nil, nil
} }
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline() error { func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error {
return nil return nil
} }
func (d *InMemoryFederationDatabase) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
return nil return nil
} }
func (d *InMemoryFederationDatabase) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error { func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error {
return nil return nil
} }