diff --git a/federationapi/api/api.go b/federationapi/api/api.go index e34c9e8b6..f4be53b9a 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -198,8 +198,9 @@ type PerformInviteResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` - ExcludeSelf bool `json:"exclude_self"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` + ExcludeBlacklisted bool `json:"exclude_blacklisted"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d1ae0f81..601257d4b 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") @@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 153fc40b5..29b16f373 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } // send this presence to all servers who share rooms with this user. - joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { log.WithError(err).Error("failed to get joined hosts") return true diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index b0a76eeb7..688afa8ea 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted) if err != nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 09098cd1e..b15b8bfae 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a8..9a3977560 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -66,14 +66,20 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - deleteJoinedHostsForRoomStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt + selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil { + return + } return } @@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + stmt := s.selectJoinedHostsForRoomsStmt + if excludingBlacklisted { + stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt + } + rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index a33fa4a43..fe84e932e 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } + blacklist, err := NewPostgresBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewPostgresBlacklistTable(d.db) - if err != nil { - return nil, err - } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 4fabff7d4..7d6e8a684 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { - servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index e0e0f2873..83112c150 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +79,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + // selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] } - - sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query := selectJoinedHostsForRoomsSQL + if excludingBlacklisted { + query = selectJoinedHostsForRoomsExcludingBlacklistedSQL + } + sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e86ac817b..d13b5defc 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewSQLiteBlacklistTable(d.db) - if err != nil { - return nil, err - } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 3c116a1d0..9f4e86a6e 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -58,7 +58,7 @@ type FederationJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) } type FederationBlacklist interface { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 72c285f8e..10b8ee27f 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -165,8 +165,9 @@ func (r *Inputer) processRoomEvent( if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), - ExcludeSelf: true, + RoomID: event.RoomID(), + ExcludeSelf: true, + ExcludeBlacklisted: true, } if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)