Merge branch 'master' into add-receipts

This commit is contained in:
Neil Alexander 2020-10-13 11:55:01 +01:00 committed by GitHub
commit 16515ebc40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 270 additions and 120 deletions

View file

@ -38,21 +38,38 @@ go run github.com/matrix-org/dendrite/cmd/generate-keys \
--tls-key=server.key --tls-key=server.key
``` ```
## Starting Dendrite ## Starting Dendrite as a monolith deployment
Once in place, start the dependencies: Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Additionally,
make the following changes to the configuration:
- Enable Naffka: `use_naffka: true`
Once in place, start the PostgreSQL dependency:
``` ```
docker-compose -f docker-compose.deps.yml up docker-compose -f docker-compose.deps.yml up postgres
``` ```
Wait a few seconds for Kafka and Postgres to finish starting up, and then start a monolith: Wait a few seconds for PostgreSQL to finish starting up, and then start a monolith:
``` ```
docker-compose -f docker-compose.monolith.yml up docker-compose -f docker-compose.monolith.yml up
``` ```
... or start the polylith components: ## Starting Dendrite as a polylith deployment
Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite).
Once in place, start all the dependencies:
```
docker-compose -f docker-compose.deps.yml up
```
Wait a few seconds for PostgreSQL and Kafka to finish starting up, and then start a polylith:
``` ```
docker-compose -f docker-compose.polylith.yml up docker-compose -f docker-compose.polylith.yml up

View file

@ -76,7 +76,7 @@ global:
# Naffka database options. Not required when using Kafka. # Naffka database options. Not required when using Kafka.
naffka_database: naffka_database:
connection_string: file:naffka.db connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_naffka?sslmode=disable
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -6,6 +6,9 @@ services:
restart: always restart: always
volumes: volumes:
- ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh
# To persist your PostgreSQL databases outside of the Docker image, to
# prevent data loss, you will need to add something like this:
# - ./path/to/persistent/storage:/var/lib/postgresql/data
environment: environment:
POSTGRES_PASSWORD: itsasecret POSTGRES_PASSWORD: itsasecret
POSTGRES_USER: dendrite POSTGRES_USER: dendrite

2
build/docker/postgres/create_db.sh Normal file → Executable file
View file

@ -1,4 +1,4 @@
#!/bin/bash #!/bin/sh
for db in account device mediaapi syncapi roomserver signingkeyserver keyserver federationsender appservice e2ekey naffka; do for db in account device mediaapi syncapi roomserver signingkeyserver keyserver federationsender appservice e2ekey naffka; do
createdb -U dendrite -O dendrite dendrite_$db createdb -U dendrite -O dendrite dendrite_$db

View file

@ -0,0 +1,17 @@
[Unit]
Description=Dendrite (Matrix Homeserver)
After=syslog.target
After=network.target
After=postgresql.service
[Service]
RestartSec=2s
Type=simple
User=dendrite
Group=dendrite
WorkingDirectory=/opt/dendrite/
ExecStart=/opt/dendrite/bin/dendrite-monolith-server
Restart=always
[Install]
WantedBy=multi-user.target

View file

@ -183,7 +183,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
// Process the events. // Process the events.
for _, e := range pdus { for _, e := range pdus {
if err := t.processEvent(ctx, e.Unwrap(), true); err != nil { if err := t.processEvent(ctx, e.Unwrap()); err != nil {
// If the error is due to the event itself being bad then we skip // If the error is due to the event itself being bad then we skip
// it and move onto the next event. We report an error so that the // it and move onto the next event. We report an error so that the
// sender knows that we have skipped processing it. // sender knows that we have skipped processing it.
@ -246,9 +246,6 @@ func isProcessingErrorFatal(err error) bool {
type roomNotFoundError struct { type roomNotFoundError struct {
roomID string roomID string
} }
type unmarshalError struct {
err error
}
type verifySigError struct { type verifySigError struct {
eventID string eventID string
err error err error
@ -259,7 +256,6 @@ type missingPrevEventsError struct {
} }
func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) } func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) }
func (e unmarshalError) Error() string { return fmt.Sprintf("unable to parse event: %s", e.err) }
func (e verifySigError) Error() string { func (e verifySigError) Error() string {
return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err)
} }
@ -338,11 +334,28 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
} }
} }
func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error { func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
servers := []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID)
}
return servers
}
func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
// Work out if the roomserver knows everything it needs to know to auth // Work out if the roomserver knows everything it needs to know to auth
// the event. // the event. This includes the prev_events and auth_events.
// NOTE! This is going to include prev_events that have an empty state
// snapshot. This is because we will need to re-request the event, and
its /state_ids, in order for it to exist in the roomserver correctly
// before the roomserver tries to work out
stateReq := api.QueryMissingAuthPrevEventsRequest{ stateReq := api.QueryMissingAuthPrevEventsRequest{
RoomID: e.RoomID(), RoomID: e.RoomID(),
AuthEventIDs: e.AuthEventIDs(), AuthEventIDs: e.AuthEventIDs(),
@ -350,7 +363,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
} }
var stateResp api.QueryMissingAuthPrevEventsResponse var stateResp api.QueryMissingAuthPrevEventsResponse
if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil {
return err return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err)
} }
if !stateResp.RoomExists { if !stateResp.RoomExists {
@ -365,52 +378,14 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
if len(stateResp.MissingAuthEventIDs) > 0 { if len(stateResp.MissingAuthEventIDs) > 0 {
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
servers := []gomatrixserverlib.ServerName{t.Origin} return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(servers))
}
getAuthEvent:
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
for _, server := range servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
continue // try the next server
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID)
continue // try the next server
}
if err = api.SendInputRoomEvents(
context.Background(),
t.rsAPI,
[]api.InputRoomEvent{
{
Kind: api.KindOutlier,
Event: ev.Headered(stateResp.RoomVersion),
AuthEventIDs: ev.AuthEventIDs(),
SendAsServer: api.DoNotSendToOtherServers,
},
},
); err != nil {
logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID)
continue getAuthEvent // move onto the next event
}
}
} }
} }
if len(stateResp.MissingPrevEventIDs) > 0 { if len(stateResp.MissingPrevEventIDs) > 0 {
logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs))
return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion)
} }
// pass the event to the roomserver which will do auth checks // pass the event to the roomserver which will do auth checks
@ -427,6 +402,60 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
) )
} }
// retrieveMissingAuthEvents tries to fetch every auth event listed in
// stateResp.MissingAuthEventIDs over federation and hands each one to the
// roomserver as an outlier.
//
// The candidate servers come from t.getServers (the transaction origin
// followed by any servers the roomserver reports as joined to the room),
// capped at the first five. For each missing event ID the servers are tried
// in order until one of them returns a PDU that parses for
// stateResp.RoomVersion; fetch and parse failures are logged and the next
// server is tried.
//
// It returns an error if an event could not be sent to the roomserver, or if
// one or more auth events could not be retrieved from any of the servers.
func (t *txnReq) retrieveMissingAuthEvents(
	ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse,
) error {
	logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())

	// Track the outstanding event IDs as a set, which also deduplicates them.
	// Entries are removed as they are successfully retrieved, so whatever is
	// left at the end is what we failed to fetch.
	missingAuthEvents := make(map[string]struct{}, len(stateResp.MissingAuthEventIDs))
	for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
		missingAuthEvents[missingAuthEventID] = struct{}{}
	}

	// Limit how many servers we are willing to ask for each event.
	servers := t.getServers(ctx, e.RoomID())
	if len(servers) > 5 {
		servers = servers[:5]
	}

withNextEvent:
	for missingAuthEventID := range missingAuthEvents {
	withNextServer:
		for _, server := range servers {
			logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
			tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
			if err != nil {
				logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID)
				continue withNextServer
			}
			// A remote server may answer successfully but with no PDUs.
			// Indexing PDUs[0] in that case would panic, so treat it like
			// any other failed fetch and move on to the next server.
			if len(tx.PDUs) == 0 {
				logger.Warnf("Server %q returned no PDUs for auth event %q", server, missingAuthEventID)
				continue withNextServer
			}
			ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
			if err != nil {
				logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
				continue withNextServer
			}

			// The event is sent with a background context rather than ctx.
			// NOTE(review): presumably so that the roomserver input is not
			// abandoned if the inbound request is cancelled - confirm.
			if err = api.SendInputRoomEvents(
				context.Background(),
				t.rsAPI,
				[]api.InputRoomEvent{
					{
						Kind:         api.KindOutlier,
						Event:        ev.Headered(stateResp.RoomVersion),
						AuthEventIDs: ev.AuthEventIDs(),
						SendAsServer: api.DoNotSendToOtherServers,
					},
				},
			); err != nil {
				return fmt.Errorf("api.SendInputRoomEvents: %w", err)
			}

			// Deleting the key currently being visited is safe while
			// ranging over a map in Go.
			delete(missingAuthEvents, missingAuthEventID)
			continue withNextEvent
		}
	}

	if missing := len(missingAuthEvents); missing > 0 {
		return fmt.Errorf("Event refers to %d auth_events which we failed to fetch", missing)
	}
	return nil
}
func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil) authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents { for i := range stateEvents {
@ -438,7 +467,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver
return gomatrixserverlib.Allowed(e, &authUsingState) return gomatrixserverlib.Allowed(e, &authUsingState)
} }
func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error {
// Do this with a fresh context, so that we keep working even if the // Do this with a fresh context, so that we keep working even if the
// original request times out. With any luck, by the time the remote // original request times out. With any luck, by the time the remote
// side retries, we'll have fetched the missing state. // side retries, we'll have fetched the missing state.
@ -464,39 +493,82 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
// - fill in the gap completely then process event `e` returning no backwards extremity // - fill in the gap completely then process event `e` returning no backwards extremity
// - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to terminate the transaction err=not nil
// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn) newEvents, err := t.getMissingEvents(gmectx, e, roomVersion)
if err != nil { if err != nil {
return err return err
} }
if backwardsExtremity == nil { if len(newEvents) == 0 {
// we filled in the gap!
return nil return nil
} }
backwardsExtremity := &newEvents[0]
newEvents = newEvents[1:]
// at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity.
// security: we have to do state resolution on the new backwards extremity (TODO: WHY)
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
var states []*gomatrixserverlib.RespState var states []*gomatrixserverlib.RespState
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
for _, prevEventID := range backwardsExtremity.PrevEventIDs() { for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
// Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not.
var prevState *gomatrixserverlib.RespState var prevState *gomatrixserverlib.RespState
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
return err return err
} }
// Append the state onto the collected state. We'll run this through the
// state resolution next.
states = append(states, prevState) states = append(states, prevState)
} }
// Now that we have collected all of the state from the prev_events, we'll
// run the state through the appropriate state resolution algorithm for the
// room. This does a couple of things:
// 1. Ensures that the state is deduplicated fully for each state-key tuple
// 2. Ensures that we pick the latest events from both sets, in the case that
// one of the prev_events is quite a bit older than the others
resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
return err return err
} }
// pass the event along with the state to the roomserver using a background context so we don't // First of all, send the backward extremity into the roomserver with the
// needlessly expire // newly resolved state. This marks the "oldest" point in the backfill and
return api.SendEventWithState(context.Background(), t.rsAPI, resolvedState, e.Headered(roomVersion), t.haveEventIDs()) // sets the baseline state for any new events after this.
err = api.SendEventWithState(
context.Background(),
t.rsAPI,
resolvedState,
backwardsExtremity.Headered(roomVersion),
t.haveEventIDs(),
)
if err != nil {
return fmt.Errorf("api.SendEventWithState: %w", err)
}
// Then send all of the newer backfilled events, all of which will be newer
// than the backward extremity, into the roomserver without state. This way
// they will automatically fast-forward based on the room state at the
// extremity in the last step.
headeredNewEvents := make([]gomatrixserverlib.HeaderedEvent, len(newEvents))
for i, newEvent := range newEvents {
headeredNewEvents[i] = newEvent.Headered(roomVersion)
}
if err = api.SendEvents(
context.Background(),
t.rsAPI,
append(headeredNewEvents, e.Headered(roomVersion)),
api.DoNotSendToOtherServers,
nil,
); err != nil {
return fmt.Errorf("api.SendEvents: %w", err)
}
return nil
} }
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
@ -510,18 +582,23 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
}
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
} }
// fetch the event we're missing and add it to the pile // fetch the event we're missing and add it to the pile
h, err := t.lookupEvent(ctx, roomVersion, eventID, false) h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return respState, nil return respState, nil
case nil: case nil:
// do nothing // do nothing
default: default:
return nil, err return nil, fmt.Errorf("t.lookupEvent: %w", err)
} }
t.haveEvents[h.EventID()] = h t.haveEvents[h.EventID()] = h
if h.StateKey() != nil { if h.StateKey() != nil {
@ -622,7 +699,11 @@ retryAllowedState:
if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true) servers := t.getServers(ctx, backwardsExtremity.RoomID())
if len(servers) > 5 {
servers = servers[:5]
}
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true, servers)
switch err2.(type) { switch err2.(type) {
case verifySigError: case verifySigError:
return &gomatrixserverlib.RespState{ return &gomatrixserverlib.RespState{
@ -652,11 +733,7 @@ retryAllowedState:
// This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns.
// This means that we may recursively call this function, as we spider back up prev_events. // This means that we may recursively call this function, as we spider back up prev_events.
// nolint:gocyclo // nolint:gocyclo
func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.Event, err error) {
if !isInboundTxn {
// we've recursed here, so just take a state snapshot please!
return &e, nil
}
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e})
// query latest events (our trusted forward extremities) // query latest events (our trusted forward extremities)
@ -667,7 +744,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event
var res api.QueryLatestEventsAndStateResponse var res api.QueryLatestEventsAndStateResponse
if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil {
logger.WithError(err).Warn("Failed to query latest events") logger.WithError(err).Warn("Failed to query latest events")
return &e, nil return nil, err
} }
latestEvents := make([]string, len(res.LatestEvents)) latestEvents := make([]string, len(res.LatestEvents))
for i := range res.LatestEvents { for i := range res.LatestEvents {
@ -726,7 +803,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event
logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) logger.Infof("get_missing_events returned %d events", len(missingResp.Events))
// topologically sort and sanity check that we are making forward progress // topologically sort and sanity check that we are making forward progress
newEvents := gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents)
shouldHaveSomeEventIDs := e.PrevEventIDs() shouldHaveSomeEventIDs := e.PrevEventIDs()
hasPrevEvent := false hasPrevEvent := false
Event: Event:
@ -749,16 +826,9 @@ Event:
err: err, err: err,
} }
} }
// process the missing events then the event which started this whole thing
for _, ev := range append(newEvents, e) {
err := t.processEvent(ctx, ev, false)
if err != nil {
return nil, err
}
}
// we processed everything! // we processed everything!
return nil, nil return newEvents, nil
} }
func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
@ -838,6 +908,12 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
"concurrent_requests": concurrentRequests, "concurrent_requests": concurrentRequests,
}).Info("Fetching missing state at event") }).Info("Fetching missing state at event")
// Get a list of servers to fetch from.
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// Create a queue containing all of the missing event IDs that we want // Create a queue containing all of the missing event IDs that we want
// to retrieve. // to retrieve.
pending := make(chan string, missingCount) pending := make(chan string, missingCount)
@ -863,7 +939,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return return
@ -901,26 +977,25 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
} }
func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) (
*gomatrixserverlib.RespState, error) { *gomatrixserverlib.RespState, error) { // nolint:unparam
// create a RespState response using the response to /state_ids as a guide // create a RespState response using the response to /state_ids as a guide
respState := gomatrixserverlib.RespState{ respState := gomatrixserverlib.RespState{}
AuthEvents: make([]gomatrixserverlib.Event, len(stateIDs.AuthEventIDs)),
StateEvents: make([]gomatrixserverlib.Event, len(stateIDs.StateEventIDs)),
}
for i := range stateIDs.StateEventIDs { for i := range stateIDs.StateEventIDs {
ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]]
if !ok { if !ok {
return nil, fmt.Errorf("missing state event %s", stateIDs.StateEventIDs[i]) logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
continue
} }
respState.StateEvents[i] = ev.Unwrap() respState.StateEvents = append(respState.StateEvents, ev.Unwrap())
} }
for i := range stateIDs.AuthEventIDs { for i := range stateIDs.AuthEventIDs {
ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
if !ok { if !ok {
return nil, fmt.Errorf("missing auth event %s", stateIDs.AuthEventIDs[i]) logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
continue
} }
respState.AuthEvents[i] = ev.Unwrap() respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap())
} }
// We purposefully do not do auth checks on the returned events, as they will still // We purposefully do not do auth checks on the returned events, as they will still
// be processed in the exact same way, just as a 'rejected' event // be processed in the exact same way, just as a 'rejected' event
@ -928,7 +1003,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
return &respState, nil return &respState, nil
} }
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool, servers []gomatrixserverlib.ServerName) (*gomatrixserverlib.HeaderedEvent, error) {
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
queryReq := api.QueryEventsByIDRequest{ queryReq := api.QueryEventsByIDRequest{
@ -941,19 +1016,27 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
return &queryRes.Events[0], nil return &queryRes.Events[0], nil
} }
} }
txn, err := t.federation.GetEvent(ctx, t.Origin, missingEventID)
if err != nil || len(txn.PDUs) == 0 {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID")
return nil, err
}
pdu := txn.PDUs[0]
var event gomatrixserverlib.Event var event gomatrixserverlib.Event
event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) found := false
if err != nil { for _, serverName := range servers {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
return nil, unmarshalError{err} if err != nil || len(txn.PDUs) == 0 {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")
continue
}
event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event")
continue
}
found = true
break
} }
if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { if !found {
util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers))
}
if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }

View file

@ -491,7 +491,7 @@ func TestTransactionFailAuthChecks(t *testing.T) {
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
return api.QueryMissingAuthPrevEventsResponse{ return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true, RoomExists: true,
MissingAuthEventIDs: []string{"create_event"}, MissingAuthEventIDs: []string{},
MissingPrevEventIDs: []string{}, MissingPrevEventIDs: []string{},
} }
}, },
@ -516,6 +516,23 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions
rsAPI = &testRoomserverAPI{ rsAPI = &testRoomserverAPI{
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
res := api.QueryEventsByIDResponse{}
for _, ev := range testEvents {
for _, id := range req.EventIDs {
if ev.EventID() == id {
res.Events = append(res.Events, ev)
}
}
}
return res
},
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryStateAfterEventsResponse{
PrevEventsExist: true,
StateEvents: testEvents[:5],
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
missingPrevEvent := []string{"missing_prev_event"} missingPrevEvent := []string{"missing_prev_event"}
if len(req.PrevEventIDs) == 1 { if len(req.PrevEventIDs) == 1 {

View file

@ -196,6 +196,11 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
} }
// No longer reuse the request context from this point forward.
// We don't want the client timing out to interrupt the join.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(context.Background())
// Try to perform a send_join using the newly built event. // Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin( respSendJoin, err := r.federation.SendJoin(
ctx, ctx,
@ -205,11 +210,16 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
) )
if err != nil { if err != nil {
r.statistics.ForServer(serverName).Failure() r.statistics.ForServer(serverName).Failure()
cancel()
return fmt.Errorf("r.federation.SendJoin: %w", err) return fmt.Errorf("r.federation.SendJoin: %w", err)
} }
r.statistics.ForServer(serverName).Success() r.statistics.ForServer(serverName).Success()
// Sanity-check the join response to ensure that it has a create
// event, that the room version is known, etc.
if err := sanityCheckSendJoinResponse(respSendJoin); err != nil { if err := sanityCheckSendJoinResponse(respSendJoin); err != nil {
return err cancel()
return fmt.Errorf("sanityCheckSendJoinResponse: %w", err)
} }
// Process the join response in a goroutine. The idea here is // Process the join response in a goroutine. The idea here is
@ -217,8 +227,6 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
// to complete, but if the client does give up waiting, we'll // to complete, but if the client does give up waiting, we'll
// still continue to process the join anyway so that we don't // still continue to process the join anyway so that we don't
// waste the effort. // waste the effort.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(context.Background())
go func() { go func() {
defer cancel() defer cancel()

View file

@ -93,7 +93,7 @@ func trackGoID(query string) {
if strings.HasPrefix(q, "SELECT") { if strings.HasPrefix(q, "SELECT") {
return // SELECTs can go on other goroutines return // SELECTs can go on other goroutines
} }
logrus.Warnf("unsafe goid: SQL executed not on an ExclusiveWriter: %s", q) logrus.Warnf("unsafe goid %d: SQL executed not on an ExclusiveWriter: %s", thisGoID, q)
} }
// Open opens a database specified by its database driver name and a driver-specific data source name, // Open opens a database specified by its database driver name and a driver-specific data source name,

View file

@ -122,7 +122,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents(
} }
for _, prevEventID := range request.PrevEventIDs { for _, prevEventID := range request.PrevEventIDs {
if nids, err := r.DB.EventNIDs(ctx, []string{prevEventID}); err != nil || len(nids) == 0 { if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err != nil || len(state) == 0 {
response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID)
} }
} }

View file

@ -22,12 +22,21 @@ func NewMembershipUpdater(
ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string, ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string,
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (*MembershipUpdater, error) { ) (*MembershipUpdater, error) {
roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) var roomNID types.RoomNID
if err != nil { var targetUserNID types.EventStateKeyNID
return nil, err var err error
} err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, roomID, roomVersion)
if err != nil {
return err
}
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) targetUserNID, err = d.assignStateKeyNID(ctx, txn, targetUserID)
if err != nil {
return err
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -224,10 +224,6 @@ func (s *ServerKeyAPI) handleFetcherKeys(
// Now let's look at the results that we got from this fetcher. // Now let's look at the results that we got from this fetcher.
for req, res := range fetcherResults { for req, res := range fetcherResults {
if req.ServerName == s.ServerName {
continue
}
if prev, ok := results[req]; ok { if prev, ok := results[req]; ok {
// We've already got a previous entry for this request // We've already got a previous entry for this request
// so let's see if the newly retrieved one contains a more // so let's see if the newly retrieved one contains a more