diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go index b03395a10..4c3d23b5d 100644 --- a/federationsender/consumers/keychange.go +++ b/federationsender/consumers/keychange.go @@ -13,10 +13,12 @@ package consumers import ( + "context" "encoding/json" "fmt" "github.com/Shopify/sarama" + stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/internal" @@ -32,6 +34,7 @@ type KeyChangeConsumer struct { db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName + stateAPI stateapi.CurrentStateInternalAPI } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. @@ -40,6 +43,7 @@ func NewKeyChangeConsumer( kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, store storage.Database, + stateAPI stateapi.CurrentStateInternalAPI, ) *KeyChangeConsumer { c := &KeyChangeConsumer{ consumer: &internal.ContinualConsumer{ @@ -50,6 +54,7 @@ func NewKeyChangeConsumer( queues: queues, db: store, serverName: cfg.Matrix.ServerName, + stateAPI: stateAPI, } c.consumer.ProcessMessage = c.onMessage @@ -72,19 +77,33 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { log.WithError(err).Errorf("failed to read device message from key change topic") return nil } + logger := log.WithField("user_id", m.UserID) // only send key change events which originated from us _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) if err != nil { - log.WithError(err).WithField("user_id", m.UserID).Error("Failed to extract domain from key change event") + logger.WithError(err).Error("Failed to extract domain from key change event") return nil } if originServerName != t.serverName { return nil } - // TODO: send this key change to all users who share rooms with this user. - var destinations []gomatrixserverlib.ServerName + var queryRes stateapi.QueryRoomsForUserResponse + err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{ + UserID: m.UserID, + WantMembership: "join", + }, &queryRes) + if err != nil { + logger.WithError(err).Error("failed to calculate joined rooms for user") + return nil + } + // send this key change to all servers who share rooms with this user. + destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) + if err != nil { + logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") + return nil + } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index b79499d3a..734b368fe 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -30,6 +30,8 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index af0a52581..52865996b 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -60,12 +60,16 @@ const selectJoinedHostsSQL = "" + const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" +const selectJoinedHostsForRoomsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -88,6 +92,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { + return + } return } @@ -144,6 +151,27 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( return result, rows.Err() } +func (s *joinedHostsStatements) SelectJoinedHostsForRooms( + ctx context.Context, roomIDs []string, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 52f02a28b..4a681de65 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -123,6 +123,10 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx) } +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index bd917c61a..4ae980d7c 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -59,13 +59,17 @@ const selectJoinedHostsSQL = "" + const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" +const selectJoinedHostsForRoomsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" + type joinedHostsStatements struct { - db *sql.DB - writer *sqlutil.TransactionWriter - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt + db *sql.DB + writer *sqlutil.TransactionWriter + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -89,6 +93,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { + return + } return } @@ -153,6 +160,32 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( return result, rows.Err() } +func (s *joinedHostsStatements) SelectJoinedHostsForRooms( + ctx context.Context, roomIDs []string, +) ([]gomatrixserverlib.ServerName, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i := range roomIDs { + iRoomIDs[i] = roomIDs[i] + } + + rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index 2def48d07..c6f8a2d52 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -53,6 +53,7 @@ type FederationSenderJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) } type FederationSenderRooms interface {