Don't send to ACL'd servers (#1267)

* Don't send to ACL'd servers

* Use gjson to look for room_id in EDU
This commit is contained in:
Neil Alexander 2020-08-13 14:23:37 +01:00 committed by GitHub
parent 9677a95afc
commit 4c4732a9c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 35 deletions

View file

@ -58,7 +58,8 @@ func NewInternalAPI(
} }
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationSenderDB, cfg.Matrix.ServerName, federation, rsAPI, stats, federationSenderDB, cfg.Matrix.ServerName, federation,
rsAPI, stateAPI, stats,
&queue.SigningInfo{ &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID, KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,

View file

@ -21,12 +21,13 @@ import (
"fmt" "fmt"
"sync" "sync"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/statistics"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
) )
// OutgoingQueues is a collection of queues for sending transactions to other // OutgoingQueues is a collection of queues for sending transactions to other
@ -34,6 +35,7 @@ import (
type OutgoingQueues struct { type OutgoingQueues struct {
db storage.Database db storage.Database
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
stateAPI stateapi.CurrentStateInternalAPI
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient client *gomatrixserverlib.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
@ -48,12 +50,14 @@ func NewOutgoingQueues(
origin gomatrixserverlib.ServerName, origin gomatrixserverlib.ServerName,
client *gomatrixserverlib.FederationClient, client *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
stateAPI stateapi.CurrentStateInternalAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing *SigningInfo,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
db: db, db: db,
rsAPI: rsAPI, rsAPI: rsAPI,
stateAPI: stateAPI,
origin: origin, origin: origin,
client: client, client: client,
statistics: statistics, statistics: statistics,
@ -128,14 +132,33 @@ func (oqs *OutgoingQueues) SendEvent(
) )
} }
// Remove our own server from the list of destinations. // Deduplicate destinations and remove the origin from the list of
destinations = filterAndDedupeDests(oqs.origin, destinations) // destinations just to be sure.
if len(destinations) == 0 { destmap := map[gomatrixserverlib.ServerName]struct{}{}
for _, d := range destinations {
destmap[d] = struct{}{}
}
delete(destmap, oqs.origin)
// Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap {
if stateapi.IsServerBannedFromRoom(
context.TODO(),
oqs.stateAPI,
ev.RoomID(),
destination,
) {
delete(destmap, destination)
}
}
// If there are no remaining destinations then give up.
if len(destmap) == 0 {
return nil return nil
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"destinations": destinations, "event": ev.EventID(), "destinations": len(destmap), "event": ev.EventID(),
}).Infof("Sending event") }).Infof("Sending event")
headeredJSON, err := json.Marshal(ev) headeredJSON, err := json.Marshal(ev)
@ -148,7 +171,7 @@ func (oqs *OutgoingQueues) SendEvent(
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }
for _, destination := range destinations { for destination := range destmap {
oqs.getQueue(destination).sendEvent(nid) oqs.getQueue(destination).sendEvent(nid)
} }
@ -164,7 +187,7 @@ func (oqs *OutgoingQueues) SendInvite(
if stateKey == nil { if stateKey == nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event_id": ev.EventID(),
}).Info("invite had no state key, dropping") }).Info("Invite had no state key, dropping")
return nil return nil
} }
@ -173,7 +196,20 @@ func (oqs *OutgoingQueues) SendInvite(
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event_id": ev.EventID(),
"state_key": stateKey, "state_key": stateKey,
}).Info("failed to split destination from state key") }).Info("Failed to split destination from state key")
return nil
}
if stateapi.IsServerBannedFromRoom(
context.TODO(),
oqs.stateAPI,
ev.RoomID(),
destination,
) {
log.WithFields(log.Fields{
"room_id": ev.RoomID(),
"destination": destination,
}).Info("Dropping invite to server which is prohibited by ACLs")
return nil return nil
} }
@ -200,14 +236,40 @@ func (oqs *OutgoingQueues) SendEDU(
) )
} }
// Remove our own server from the list of destinations. // Deduplicate destinations and remove the origin from the list of
destinations = filterAndDedupeDests(oqs.origin, destinations) // destinations just to be sure.
destmap := map[gomatrixserverlib.ServerName]struct{}{}
if len(destinations) > 0 { for _, d := range destinations {
log.WithFields(log.Fields{ destmap[d] = struct{}{}
"destinations": destinations, "edu_type": e.Type,
}).Info("Sending EDU event")
} }
delete(destmap, oqs.origin)
// There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does*
// (e.g. typing notifications) then we should try to make sure we don't
// bother sending them to servers that are prohibited by the server
// ACLs.
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
for destination := range destmap {
if stateapi.IsServerBannedFromRoom(
context.TODO(),
oqs.stateAPI,
result.Str,
destination,
) {
delete(destmap, destination)
}
}
}
// If there are no remaining destinations then give up.
if len(destmap) == 0 {
return nil
}
log.WithFields(log.Fields{
"destinations": len(destmap), "edu_type": e.Type,
}).Info("Sending EDU event")
ephemeralJSON, err := json.Marshal(e) ephemeralJSON, err := json.Marshal(e)
if err != nil { if err != nil {
@ -219,7 +281,7 @@ func (oqs *OutgoingQueues) SendEDU(
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }
for _, destination := range destinations { for destination := range destmap {
oqs.getQueue(destination).sendEDU(nid) oqs.getQueue(destination).sendEDU(nid)
} }
@ -234,21 +296,3 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
} }
q.wakeQueueIfNeeded() q.wakeQueueIfNeeded()
} }
// filterAndDedupeDests removes our own server from the list of destinations
// and deduplicates any servers in the list that may appear more than once.
func filterAndDedupeDests(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) (
result []gomatrixserverlib.ServerName,
) {
strs := make([]string, len(destinations))
for i, d := range destinations {
strs[i] = string(d)
}
for _, destination := range util.UniqueStrings(strs) {
if gomatrixserverlib.ServerName(destination) == origin {
continue
}
result = append(result, gomatrixserverlib.ServerName(destination))
}
return result
}