Don't get blacklisted hosts when querying joined servers (#2880)

Otherwise we just waste time/CPU.
This commit is contained in:
Neil Alexander 2022-11-15 17:21:16 +00:00 committed by GitHub
parent 5c9aed6af9
commit 9b8bb55430
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 32 deletions

View file

@ -200,6 +200,7 @@ type PerformInviteResponse struct {
type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"`
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
}
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames

View file

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return true
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
}
// send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil {
log.WithError(err).Error("failed to get joined hosts")
return true

View file

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf)
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
if err != nil {
return
}

View file

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
}
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
return
}
return
}
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
stmt := s.selectJoinedHostsForRoomsStmt
if excludingBlacklisted {
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}

View file

@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil {
return nil, err
@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil {
return nil, err

View file

@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil {
return nil, err
}

View file

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs {
iRoomIDs[i] = roomIDs[i]
}
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
query := selectJoinedHostsForRoomsSQL
if excludingBlacklisted {
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
}
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil {
return nil, err

View file

@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil {
return nil, err
@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil {
return nil, err

View file

@ -58,7 +58,7 @@ type FederationJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
}
type FederationBlacklist interface {

View file

@ -167,6 +167,7 @@ func (r *Inputer) processRoomEvent(
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(),
ExcludeSelf: true,
ExcludeBlacklisted: true,
}
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)