mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Send device list updates to all required servers
This commit is contained in:
parent
b67e0b3374
commit
1ed7102f07
|
|
@ -13,10 +13,12 @@
|
||||||
package consumers
|
package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
|
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
"github.com/matrix-org/dendrite/federationsender/queue"
|
"github.com/matrix-org/dendrite/federationsender/queue"
|
||||||
"github.com/matrix-org/dendrite/federationsender/storage"
|
"github.com/matrix-org/dendrite/federationsender/storage"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
|
@ -32,6 +34,7 @@ type KeyChangeConsumer struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
queues *queue.OutgoingQueues
|
queues *queue.OutgoingQueues
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
|
stateAPI stateapi.CurrentStateInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
|
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
|
||||||
|
|
@ -40,6 +43,7 @@ func NewKeyChangeConsumer(
|
||||||
kafkaConsumer sarama.Consumer,
|
kafkaConsumer sarama.Consumer,
|
||||||
queues *queue.OutgoingQueues,
|
queues *queue.OutgoingQueues,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
|
stateAPI stateapi.CurrentStateInternalAPI,
|
||||||
) *KeyChangeConsumer {
|
) *KeyChangeConsumer {
|
||||||
c := &KeyChangeConsumer{
|
c := &KeyChangeConsumer{
|
||||||
consumer: &internal.ContinualConsumer{
|
consumer: &internal.ContinualConsumer{
|
||||||
|
|
@ -50,6 +54,7 @@ func NewKeyChangeConsumer(
|
||||||
queues: queues,
|
queues: queues,
|
||||||
db: store,
|
db: store,
|
||||||
serverName: cfg.Matrix.ServerName,
|
serverName: cfg.Matrix.ServerName,
|
||||||
|
stateAPI: stateAPI,
|
||||||
}
|
}
|
||||||
c.consumer.ProcessMessage = c.onMessage
|
c.consumer.ProcessMessage = c.onMessage
|
||||||
|
|
||||||
|
|
@ -72,19 +77,33 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
log.WithError(err).Errorf("failed to read device message from key change topic")
|
log.WithError(err).Errorf("failed to read device message from key change topic")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
logger := log.WithField("user_id", m.UserID)
|
||||||
|
|
||||||
// only send key change events which originated from us
|
// only send key change events which originated from us
|
||||||
_, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
|
_, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("user_id", m.UserID).Error("Failed to extract domain from key change event")
|
logger.WithError(err).Error("Failed to extract domain from key change event")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if originServerName != t.serverName {
|
if originServerName != t.serverName {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: send this key change to all users who share rooms with this user.
|
var queryRes stateapi.QueryRoomsForUserResponse
|
||||||
var destinations []gomatrixserverlib.ServerName
|
err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{
|
||||||
|
UserID: m.UserID,
|
||||||
|
WantMembership: "join",
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("failed to calculate joined rooms for user")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// send this key change to all servers who share rooms with this user.
|
||||||
|
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Pack the EDU and marshal it
|
// Pack the EDU and marshal it
|
||||||
edu := &gomatrixserverlib.EDU{
|
edu := &gomatrixserverlib.EDU{
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ type Database interface {
|
||||||
|
|
||||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||||
|
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
|
||||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,12 +60,16 @@ const selectJoinedHostsSQL = "" +
|
||||||
const selectAllJoinedHostsSQL = "" +
|
const selectAllJoinedHostsSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
||||||
|
|
||||||
|
const selectJoinedHostsForRoomsSQL = "" +
|
||||||
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
|
selectJoinedHostsForRoomsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
|
|
@ -88,6 +92,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
|
||||||
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,6 +151,27 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||||
|
ctx context.Context, roomIDs []string,
|
||||||
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
var result []gomatrixserverlib.ServerName
|
||||||
|
for rows.Next() {
|
||||||
|
var serverName string
|
||||||
|
if err = rows.Scan(&serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func joinedHostsFromStmt(
|
func joinedHostsFromStmt(
|
||||||
ctx context.Context, stmt *sql.Stmt, roomID string,
|
ctx context.Context, stmt *sql.Stmt, roomID string,
|
||||||
) ([]types.JoinedHost, error) {
|
) ([]types.JoinedHost, error) {
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,10 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
|
||||||
return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx)
|
return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
|
||||||
|
}
|
||||||
|
|
||||||
// StoreJSON adds a JSON blob into the queue JSON table and returns
|
// StoreJSON adds a JSON blob into the queue JSON table and returns
|
||||||
// a NID. The NID will then be used when inserting the per-destination
|
// a NID. The NID will then be used when inserting the per-destination
|
||||||
// metadata entries.
|
// metadata entries.
|
||||||
|
|
|
||||||
|
|
@ -59,13 +59,17 @@ const selectJoinedHostsSQL = "" +
|
||||||
const selectAllJoinedHostsSQL = "" +
|
const selectAllJoinedHostsSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
||||||
|
|
||||||
|
const selectJoinedHostsForRoomsSQL = "" +
|
||||||
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer *sqlutil.TransactionWriter
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
|
selectJoinedHostsForRoomsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
|
|
@ -89,6 +93,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
|
||||||
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -153,6 +160,32 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||||
|
ctx context.Context, roomIDs []string,
|
||||||
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
|
for i := range roomIDs {
|
||||||
|
iRoomIDs[i] = roomIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
var result []gomatrixserverlib.ServerName
|
||||||
|
for rows.Next() {
|
||||||
|
var serverName string
|
||||||
|
if err = rows.Scan(&serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func joinedHostsFromStmt(
|
func joinedHostsFromStmt(
|
||||||
ctx context.Context, stmt *sql.Stmt, roomID string,
|
ctx context.Context, stmt *sql.Stmt, roomID string,
|
||||||
) ([]types.JoinedHost, error) {
|
) ([]types.JoinedHost, error) {
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ type FederationSenderJoinedHosts interface {
|
||||||
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FederationSenderRooms interface {
|
type FederationSenderRooms interface {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue