diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 8f96c3d35..3c4a35216 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -462,6 +462,7 @@ func createRoom( AuthEvents: accumulated, }, ev.Headered(roomVersion), + cfg.Matrix.ServerName, nil, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEventWithState failed") diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 33fb38831..29fe8231a 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -109,6 +109,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, nil, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 7bea35e50..8d77b1580 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -169,7 +169,7 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -286,7 +286,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index c25ca4eff..627f049f4 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -120,7 +120,7 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, nil); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 204d2592a..12e625be4 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -121,6 +121,7 @@ func SendEvent( e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, txnAndSessionID, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 53cd6b8ca..a0bbb42b1 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -366,6 +366,7 @@ func emit3PIDInviteEvent( event.Headered(queryRes.RoomVersion), }, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, nil, ) } diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 0cab7d5a4..487b51bc0 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -24,6 +24,7 @@ type FederationClient interface { MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 66611fe29..8937f0b52 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -106,6 +106,21 @@ func (a *FederationInternalAPI) LookupStateIDs( return ires.(gomatrixserverlib.RespStateIDs), nil } +func (a *FederationInternalAPI) LookupMissingEvents( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespMissingEvents, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + }) + if err != nil { + return gomatrixserverlib.RespMissingEvents{}, err + } + return ires.(gomatrixserverlib.RespMissingEvents), nil +} + func (a *FederationInternalAPI) GetEvent( ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 82d04c21e..edf7a9209 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -249,6 +249,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), + serverName, nil, ); err != nil { logrus.WithFields(logrus.Fields{ @@ -430,6 +431,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), + serverName, nil, ); err != nil { return fmt.Errorf("r.producer.SendEventWithState: %w", err) diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index fa9fae343..d056cb382 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -26,17 +26,18 @@ const ( FederationAPIPerformServersAlivePath = "/federationapi/performServersAlive" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" - FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" - FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" - FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" - FederationAPIBackfillPath = "/federationapi/client/backfill" - FederationAPILookupStatePath = "/federationapi/client/lookupState" - FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" - FederationAPIGetEventPath = "/federationapi/client/getEvent" - FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" - FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" - FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" - FederationAPIGetEventAuthPath = "/federationapi/client/getEventAuth" + FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" + FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" + FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" + FederationAPIBackfillPath = "/federationapi/client/backfill" + FederationAPILookupStatePath = "/federationapi/client/lookupState" + FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" + FederationAPILookupMissingEventsPath = "/federationapi/client/lookupMissingEvents" + FederationAPIGetEventPath = "/federationapi/client/getEvent" + FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" + FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" + FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" + FederationAPIGetEventAuthPath = "/federationapi/client/getEventAuth" FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey" FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey" @@ -354,6 +355,40 @@ func (h *httpFederationInternalAPI) LookupStateIDs( return *response.Res, nil } +type lookupMissingEvents struct { + S gomatrixserverlib.ServerName + RoomID string + Missing gomatrixserverlib.MissingEvents + RoomVersion gomatrixserverlib.RoomVersion + Res *gomatrixserverlib.RespMissingEvents + Err *api.FederationClientError +} + +func (h *httpFederationInternalAPI) LookupMissingEvents( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, +) (gomatrixserverlib.RespMissingEvents, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents") + defer span.Finish() + + request := lookupMissingEvents{ + S: s, + RoomID: roomID, + Missing: missing, + RoomVersion: roomVersion, + } + var response lookupMissingEvents + apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespMissingEvents{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespMissingEvents{}, response.Err + } + return *response.Res, nil +} + type getEvent struct { S gomatrixserverlib.ServerName EventID string diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 4a5af0706..104df1a87 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -241,6 +241,28 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationAPILookupMissingEventsPath, + httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse { + var request lookupMissingEvents + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) internalAPIMux.Handle( FederationAPIGetEventPath, httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 9341c3298..8678f2240 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,7 +17,6 @@ package routing import ( "context" "encoding/json" - "errors" "fmt" "net/http" "sync" @@ -201,8 +200,6 @@ func Send( eduAPI: eduAPI, keys: keys, federation: federation, - hadEvents: make(map[string]bool), - haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), servers: servers, keyAPI: keyAPI, roomsMu: mu, @@ -263,22 +260,8 @@ type txnReq struct { keys gomatrixserverlib.JSONVerifier federation txnFederationClient roomsMu *internal.MutexByRoom - // something that can tell us about which servers are in a room right now - servers federationAPI.ServersInRoomProvider - // a list of events from the auth and prev events which we already had - hadEvents map[string]bool - hadEventsMutex sync.Mutex - // local cache of events for auth checks, etc - this may include events - // which the roomserver is unaware of. - haveEvents map[string]*gomatrixserverlib.HeaderedEvent - haveEventsMutex sync.Mutex - work string // metrics -} - -func (t *txnReq) hadEvent(eventID string, had bool) { - t.hadEventsMutex.Lock() - defer t.hadEventsMutex.Unlock() - t.hadEvents[eventID] = had + servers federationAPI.ServersInRoomProvider + work string } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -440,22 +423,8 @@ func (t *inputWorker) run() { type roomNotFoundError struct { roomID string } -type verifySigError struct { - eventID string - err error -} -type missingPrevEventsError struct { - eventID string - err error -} func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) } -func (e verifySigError) Error() string { - return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) -} -func (e missingPrevEventsError) Error() string { - return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) -} func (t *txnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { @@ -599,28 +568,7 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli } } -func (t *txnReq) getServers(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName { - // The server that sent us the event should be sufficient to tell us about missing - // prev and auth events. - servers := []gomatrixserverlib.ServerName{t.Origin} - // If the event origin is different to the transaction origin then we can use - // this as a last resort. The origin server that created the event would have - // had to know the auth and prev events. - if event != nil { - if origin := event.Origin(); origin != t.Origin { - servers = append(servers, origin) - } - } - // If a specific room-to-server provider exists then use that. This will primarily - // be used for the P2P demos. - if t.servers != nil { - servers = append(servers, t.servers.GetServersForRoom(ctx, roomID, event)...) - } - return servers -} - func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { - logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) t.work = "" // reset from previous event // Ask the roomserver if we know about the room and/or if we're joined @@ -648,30 +596,14 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e // before the roomserver tries to work out stateReq := api.QueryMissingAuthPrevEventsRequest{ RoomID: e.RoomID(), - AuthEventIDs: e.AuthEventIDs(), - PrevEventIDs: e.PrevEventIDs(), + AuthEventIDs: nil, //e.AuthEventIDs(), + PrevEventIDs: nil, //e.PrevEventIDs(), } var stateResp api.QueryMissingAuthPrevEventsResponse if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err) } - // Prepare a map of all the events we already had before this point, so - // that we don't send them to the roomserver again. - for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { - t.hadEvent(eventID, true) - } - for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { - t.hadEvent(eventID, false) - } - - if len(stateResp.MissingPrevEventIDs) > 0 { - t.work = MetricsWorkMissingPrevEvents - logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) - return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion) - } - t.work = MetricsWorkDirect - // pass the event to the roomserver which will do auth checks // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently // discarded by the caller of this function @@ -682,656 +614,8 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e []*gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), }, + t.Origin, api.DoNotSendToOtherServers, nil, ) } - -func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { - authUsingState := gomatrixserverlib.NewAuthEvents(nil) - for i := range stateEvents { - err := authUsingState.AddEvent(stateEvents[i]) - if err != nil { - return err - } - } - return gomatrixserverlib.Allowed(e, &authUsingState) -} - -func (t *txnReq) processEventWithMissingState( - ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, -) error { - // We are missing the previous events for this events. - // This means that there is a gap in our view of the history of the - // room. There two ways that we can handle such a gap: - // 1) We can fill in the gap using /get_missing_events - // 2) We can leave the gap and request the state of the room at - // this event from the remote server using either /state_ids - // or /state. - // Synapse will attempt to do 1 and if that fails or if the gap is - // too large then it will attempt 2. - // Synapse will use /state_ids if possible since usually the state - // is largely unchanged and it is more efficient to fetch a list of - // event ids and then use /event to fetch the individual events. - // However not all version of synapse support /state_ids so you may - // need to fallback to /state. - - // Attempt to fill in the gap using /get_missing_events - // This will either: - // - fill in the gap completely then process event `e` returning no backwards extremity - // - fail to fill in the gap and tell us to terminate the transaction err=not nil - // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, err := t.getMissingEvents(ctx, e, roomVersion) - if err != nil { - return err - } - if len(newEvents) == 0 { - return nil - } - - backwardsExtremity := newEvents[0] - newEvents = newEvents[1:] - - type respState struct { - // A snapshot is considered trustworthy if it came from our own roomserver. - // That's because the state will have been through state resolution once - // already in QueryStateAfterEvent. - trustworthy bool - *gomatrixserverlib.RespState - } - - // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. - // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query - // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. - var states []*respState - for _, prevEventID := range backwardsExtremity.PrevEventIDs() { - // Look up what the state is after the backward extremity. This will either - // come from the roomserver, if we know all the required events, or it will - // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if lerr != nil { - util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return lerr - } - // Append the state onto the collected state. We'll run this through the - // state resolution next. - states = append(states, &respState{trustworthy, prevState}) - } - - // Now that we have collected all of the state from the prev_events, we'll - // run the state through the appropriate state resolution algorithm for the - // room if needed. This does a couple of things: - // 1. Ensures that the state is deduplicated fully for each state-key tuple - // 2. Ensures that we pick the latest events from both sets, in the case that - // one of the prev_events is quite a bit older than the others - resolvedState := &gomatrixserverlib.RespState{} - switch len(states) { - case 0: - extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") - if !extremityIsCreate { - // There are no previous states and this isn't the beginning of the - // room - this is an error condition! - util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events") - return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) - } - case 1: - // There's only one previous state - if it's trustworthy (came from a - // local state snapshot which will already have been through state res), - // use it as-is. There's no point in resolving it again. - if states[0].trustworthy { - resolvedState = states[0].RespState - break - } - // Otherwise, if it isn't trustworthy (came from federation), run it through - // state resolution anyway for safety, in case there are duplicates. - fallthrough - default: - respStates := make([]*gomatrixserverlib.RespState, len(states)) - for i := range states { - respStates[i] = states[i].RespState - } - // There's more than one previous state - run them all through state res - t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) - t.roomsMu.Unlock(e.RoomID()) - if err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err - } - } - - // First of all, send the backward extremity into the roomserver with the - // newly resolved state. This marks the "oldest" point in the backfill and - // sets the baseline state for any new events after this. We'll make a - // copy of the hadEvents map so that it can be taken downstream without - // worrying about concurrent map reads/writes, since t.hadEvents is meant - // to be protected by a mutex. - hadEvents := map[string]bool{} - t.hadEventsMutex.Lock() - for k, v := range t.hadEvents { - hadEvents[k] = v - } - t.hadEventsMutex.Unlock() - err = api.SendEventWithState( - context.Background(), - t.rsAPI, - api.KindOld, - resolvedState, - backwardsExtremity.Headered(roomVersion), - hadEvents, - ) - if err != nil { - return fmt.Errorf("api.SendEventWithState: %w", err) - } - - // Then send all of the newer backfilled events, of which will all be newer - // than the backward extremity, into the roomserver without state. This way - // they will automatically fast-forward based on the room state at the - // extremity in the last step. - headeredNewEvents := make([]*gomatrixserverlib.HeaderedEvent, len(newEvents)) - for i, newEvent := range newEvents { - headeredNewEvents[i] = newEvent.Headered(roomVersion) - } - if err = api.SendEvents( - context.Background(), - t.rsAPI, - api.KindOld, - append(headeredNewEvents, e.Headered(roomVersion)), - api.DoNotSendToOtherServers, - nil, - ); err != nil { - return fmt.Errorf("api.SendEvents: %w", err) - } - - return nil -} - -// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) -// added into the mix. -func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { - // try doing all this locally before we resort to querying federation - respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) - if respState != nil { - return respState, true, nil - } - - respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) - if err != nil { - return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) - } - - // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) - switch err.(type) { - case verifySigError: - return respState, false, nil - case nil: - // do nothing - default: - return nil, false, fmt.Errorf("t.lookupEvent: %w", err) - } - h = t.cacheAndReturn(h) - if h.StateKey() != nil { - addedToState := false - for i := range respState.StateEvents { - se := respState.StateEvents[i] - if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { - respState.StateEvents[i] = h.Unwrap() - addedToState = true - break - } - } - if !addedToState { - respState.StateEvents = append(respState.StateEvents, h.Unwrap()) - } - } - - return respState, false, nil -} - -func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { - t.haveEventsMutex.Lock() - defer t.haveEventsMutex.Unlock() - if cached, exists := t.haveEvents[ev.EventID()]; exists { - return cached - } - t.haveEvents[ev.EventID()] = ev - return ev -} - -func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { - var res api.QueryStateAfterEventsResponse - err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ - RoomID: roomID, - PrevEventIDs: []string{eventID}, - }, &res) - if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) - return nil - } - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) - for i, ev := range res.StateEvents { - // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this - // processEvent request, which is better for memory. - stateEvents[i] = t.cacheAndReturn(ev) - t.hadEvent(ev.EventID(), true) - } - // we should never access res.StateEvents again so we delete it here to make GC faster - res.StateEvents = nil - - var authEvents []*gomatrixserverlib.Event - missingAuthEvents := map[string]bool{} - for _, ev := range stateEvents { - t.haveEventsMutex.Lock() - for _, ae := range ev.AuthEventIDs() { - if aev, ok := t.haveEvents[ae]; ok { - authEvents = append(authEvents, aev.Unwrap()) - } else { - missingAuthEvents[ae] = true - } - } - t.haveEventsMutex.Unlock() - } - // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't - // have stored the event. - if len(missingAuthEvents) > 0 { - var missingEventList []string - for evID := range missingAuthEvents { - missingEventList = append(missingEventList, evID) - } - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } - util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") - var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil - } - for i, ev := range queryRes.Events { - authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) - t.hadEvent(ev.EventID(), true) - } - queryRes.Events = nil - } - - return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), - AuthEvents: authEvents, - } -} - -// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what -// the server supports. -func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( - *gomatrixserverlib.RespState, error) { - - // Attempt to fetch the missing state using /state_ids and /events - return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) -} - -func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { - var authEventList []*gomatrixserverlib.Event - var stateEventList []*gomatrixserverlib.Event - for _, state := range states { - authEventList = append(authEventList, state.AuthEvents...) - stateEventList = append(stateEventList, state.StateEvents...) - } - resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) - if err != nil { - return nil, err - } - // apply the current event -retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { - switch missing := err.(type) { - case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) - switch err2.(type) { - case verifySigError: - return &gomatrixserverlib.RespState{ - AuthEvents: authEventList, - StateEvents: resolvedStateEvents, - }, nil - case nil: - // do nothing - default: - return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) - } - util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) - resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) - goto retryAllowedState - default: - } - return nil, err - } - return &gomatrixserverlib.RespState{ - AuthEvents: authEventList, - StateEvents: resolvedStateEvents, - }, nil -} - -func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) { - logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) - // query latest events (our trusted forward extremities) - req := api.QueryLatestEventsAndStateRequest{ - RoomID: e.RoomID(), - StateToFetch: needed.Tuples(), - } - var res api.QueryLatestEventsAndStateResponse - if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { - logger.WithError(err).Warn("Failed to query latest events") - return nil, err - } - latestEvents := make([]string, len(res.LatestEvents)) - for i, ev := range res.LatestEvents { - latestEvents[i] = res.LatestEvents[i].EventID - t.hadEvent(ev.EventID, true) - } - - var missingResp *gomatrixserverlib.RespMissingEvents - servers := t.getServers(ctx, e.RoomID(), e) - for _, server := range servers { - var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ - Limit: 20, - // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. - EarliestEvents: latestEvents, - // The event IDs to retrieve the previous events for. - LatestEvents: []string{e.EventID()}, - }, roomVersion); err == nil { - missingResp = &m - break - } else { - logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.Origin, server) - if errors.Is(err, context.DeadlineExceeded) { - break - } - } - } - - if missingResp == nil { - logger.WithError(err).Errorf( - "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, len(servers), - ) - return nil, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - - // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. - // There's 2 scenarios to consider: - // - Case A: We got pushed an event and are now fetching missing prev_events. (isInboundTxn=true) - // - Case B: We are fetching missing prev_events already and now fetching some more (isInboundTxn=false) - // In Case B, we know for sure that the event we are currently processing will not become the new forward extremity for the room, - // as it was called in response to an inbound txn which had it as a prev_event. - // In Case A, the event is a forward extremity, and could eventually become the _only_ forward extremity in the room. This is bad - // because it means we would trust the state at that event to be the state for the entire room, and allows rooms to be hijacked. - // https://github.com/matrix-org/synapse/pull/3456 - // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 - // For now, we do not allow Case B, so reject the event. - logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) - - // Make sure events from the missingResp are using the cache - missing events - // will be added and duplicates will be removed. - for i, ev := range missingResp.Events { - missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - - // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) - shouldHaveSomeEventIDs := e.PrevEventIDs() - hasPrevEvent := false -Event: - for _, pe := range shouldHaveSomeEventIDs { - for _, ev := range newEvents { - if ev.EventID() == pe { - hasPrevEvent = true - break Event - } - } - } - if !hasPrevEvent { - err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.Origin, shouldHaveSomeEventIDs) - logger.WithError(err).Errorf( - "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, - ) - return nil, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - - return newEvents, nil -} - -func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - respState *gomatrixserverlib.RespState, err error) { - state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion) - if err != nil { - return nil, err - } - // Check that the returned state is valid. - if err := state.Check(ctx, t.keys, nil); err != nil { - return nil, err - } - // Cache the results of this state lookup and deduplicate anything we already - // have in the cache, freeing up memory. - for i, ev := range state.AuthEvents { - state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - for i, ev := range state.StateEvents { - state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - return &state, nil -} - -func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { - util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) - // fetch the state event IDs at the time of the event - stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID) - if err != nil { - return nil, err - } - // work out which auth/state IDs are missing - wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) - missing := make(map[string]bool) - var missingEventList []string - t.haveEventsMutex.Lock() - for _, sid := range wantIDs { - if _, ok := t.haveEvents[sid]; !ok { - if !missing[sid] { - missing[sid] = true - missingEventList = append(missingEventList, sid) - } - } - } - t.haveEventsMutex.Unlock() - - // fetch as many as we can from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } - var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil, err - } - for i, ev := range queryRes.Events { - queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) - t.hadEvent(ev.EventID(), true) - evID := queryRes.Events[i].EventID() - if missing[evID] { - delete(missing, evID) - } - } - queryRes.Events = nil // allow it to be GCed - - concurrentRequests := 8 - missingCount := len(missing) - util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) - - // If over 50% of the auth/state events from /state_ids are missing - // then we'll just call /state instead, otherwise we'll just end up - // hammering the remote side with /event requests unnecessarily. - if missingCount > concurrentRequests && missingCount > len(wantIDs)/2 { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "missing": missingCount, - "event_id": eventID, - "room_id": roomID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), - }).Info("Fetching all state at event") - return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) - } - - if missingCount > 0 { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "missing": missingCount, - "event_id": eventID, - "room_id": roomID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), - "concurrent_requests": concurrentRequests, - }).Info("Fetching missing state at event") - - // Create a queue containing all of the missing event IDs that we want - // to retrieve. - pending := make(chan string, missingCount) - for missingEventID := range missing { - pending <- missingEventID - } - close(pending) - - // Define how many workers we should start to do this. - if missingCount < concurrentRequests { - concurrentRequests = missingCount - } - - // Create the wait group. - var fetchgroup sync.WaitGroup - fetchgroup.Add(concurrentRequests) - - // This is the only place where we'll write to t.haveEvents from - // multiple goroutines, and everywhere else is blocked on this - // synchronous function anyway. - var haveEventsMutex sync.Mutex - - // Define what we'll do in order to fetch the missing event ID. - fetch := func(missingEventID string) { - var h *gomatrixserverlib.HeaderedEvent - h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) - switch err.(type) { - case verifySigError: - return - case nil: - break - default: - util.GetLogger(ctx).WithFields(logrus.Fields{ - "event_id": missingEventID, - "room_id": roomID, - }).Info("Failed to fetch missing event") - return - } - haveEventsMutex.Lock() - t.cacheAndReturn(h) - haveEventsMutex.Unlock() - } - - // Create the worker. - worker := func(ch <-chan string) { - defer fetchgroup.Done() - for missingEventID := range ch { - fetch(missingEventID) - } - } - - // Start the workers. - for i := 0; i < concurrentRequests; i++ { - go worker(pending) - } - - // Wait for the workers to finish. - fetchgroup.Wait() - } - - resp, err := t.createRespStateFromStateIDs(stateIDs) - return resp, err -} - -func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( - *gomatrixserverlib.RespState, error) { // nolint:unparam - t.haveEventsMutex.Lock() - defer t.haveEventsMutex.Unlock() - - // create a RespState response using the response to /state_ids as a guide - respState := gomatrixserverlib.RespState{} - - for i := range stateIDs.StateEventIDs { - ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] - if !ok { - logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) - continue - } - respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) - } - for i := range stateIDs.AuthEventIDs { - ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] - if !ok { - logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) - continue - } - respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) - } - // We purposefully do not do auth checks on the returned events, as they will still - // be processed in the exact same way, just as a 'rejected' event - // TODO: Add a field to HeaderedEvent to indicate if the event is rejected. - return &respState, nil -} - -func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { - if localFirst { - // fetch from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: []string{missingEventID}, - } - var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) - } else if len(queryRes.Events) == 1 { - return queryRes.Events[0], nil - } - } - var event *gomatrixserverlib.Event - found := false - servers := t.getServers(ctx, roomID, nil) - for _, serverName := range servers { - txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) - if err != nil || len(txn.PDUs) == 0 { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") - if errors.Is(err, context.DeadlineExceeded) { - break - } - continue - } - event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) - if err != nil { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") - continue - } - found = true - break - } - if !found { - util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers)) - return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers)) - } - if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) - return nil, verifySigError{event.EventID(), err} - } - return t.cacheAndReturn(event.Headered(roomVersion)), nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 702884613..665246fce 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -244,8 +244,6 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat eduAPI: &testEDUProducer{}, keys: &test.NopJSONVerifier{}, federation: fedClient, - haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - hadEvents: make(map[string]bool), roomsMu: internal.NewMutexByRoom(), } t.PDUs = pdus diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 5ba28881c..3fef93036 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -178,6 +178,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), }, + request.Origin(), cfg.Matrix.ServerName, nil, ); err != nil { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 8e6e4ac7b..71eb9d758 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -54,6 +54,8 @@ type InputRoomEvent struct { Kind Kind `json:"kind"` // The event JSON for the event to add. Event *gomatrixserverlib.HeaderedEvent `json:"event"` + // Which server told us about this event. + Origin gomatrixserverlib.ServerName `json:"origin"` // List of state event IDs that authenticate this event. // These are likely derived from the "auth_events" JSON key of the event. // But can be different because the "auth_events" key can be incomplete or wrong. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index de66df803..0b8f3b757 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -22,10 +22,15 @@ import ( "github.com/matrix-org/util" ) +type RoomEventInputter interface { + InputRoomEvents(context.Context, *InputRoomEventsRequest, *InputRoomEventsResponse) +} + // SendEvents to the roomserver The events are written with KindNew. func SendEvents( - ctx context.Context, rsAPI RoomserverInternalAPI, + ctx context.Context, rsAPI RoomEventInputter, kind Kind, events []*gomatrixserverlib.HeaderedEvent, + origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, ) error { ires := make([]InputRoomEvent, len(events)) @@ -33,6 +38,7 @@ func SendEvents( ires[i] = InputRoomEvent{ Kind: kind, Event: event, + Origin: origin, AuthEventIDs: event.AuthEventIDs(), SendAsServer: string(sendAsServer), TransactionID: txnID, @@ -45,9 +51,9 @@ func SendEvents( // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI RoomserverInternalAPI, kind Kind, + ctx context.Context, rsAPI RoomEventInputter, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, - haveEventIDs map[string]bool, + origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, ) error { outliers, err := state.Events() if err != nil { @@ -62,6 +68,7 @@ func SendEventWithState( ires = append(ires, InputRoomEvent{ Kind: KindOutlier, Event: outlier.Headered(event.RoomVersion), + Origin: origin, AuthEventIDs: outlier.AuthEventIDs(), }) } @@ -74,6 +81,7 @@ func SendEventWithState( ires = append(ires, InputRoomEvent{ Kind: kind, Event: event, + Origin: origin, AuthEventIDs: event.AuthEventIDs(), HasState: true, StateEventIDs: stateEventIDs, @@ -84,7 +92,7 @@ func SendEventWithState( // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( - ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, + ctx context.Context, rsAPI RoomEventInputter, ires []InputRoomEvent, ) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index c1de3c9cb..2c3178ab6 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -92,6 +92,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, + Queryer: r.Queryer, } r.Inviter = &perform.Inviter{ DB: r.DB, diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index bc9046cf5..e100efa01 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" @@ -50,6 +51,8 @@ type Inputer struct { ACLs *acls.ServerACLs OutputRoomEventTopic string workers sync.Map // room ID -> *inputWorker + + Queryer *query.Queryer } type inputTask struct { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 52b1a7025..86a940818 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -23,6 +23,7 @@ import ( "time" fedapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" @@ -76,6 +77,11 @@ func (r *Inputer) processRoomEvent( // Parse and validate the event JSON headered := input.Event event := headered.Unwrap() + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "type": event.Type(), + }) // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. @@ -88,23 +94,35 @@ func (r *Inputer) processRoomEvent( switch idFormat { case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { - util.GetLogger(ctx).WithField("event_id", event.EventID()).Infof("Already processed event; ignoring") + logger.Debugf("Already processed event; ignoring") return event.EventID(), nil } default: - util.GetLogger(ctx).WithField("event_id", event.EventID()).Infof("Already processed event; ignoring") + logger.Debugf("Already processed event; ignoring") return event.EventID(), nil } } } } + missingReq := &api.QueryMissingAuthPrevEventsRequest{ + RoomID: event.RoomID(), + AuthEventIDs: event.AuthEventIDs(), + PrevEventIDs: event.PrevEventIDs(), + } + missingRes := &api.QueryMissingAuthPrevEventsResponse{} + if event.Type() != gomatrixserverlib.MRoomCreate { + if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { + return "", fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) + } + } + // First of all, check that the auth events of the event are known. // If they aren't then we will ask the federation API for them. isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownAuthEvents := map[string]types.Event{} - if err = r.checkForMissingAuthEvents(ctx, input.Event, &authEvents, knownAuthEvents); err != nil { + knownEvents := map[string]*types.Event{} + if err = r.checkForMissingAuthEvents(ctx, logger, input.Event, &authEvents, knownEvents); err != nil { return "", fmt.Errorf("r.checkForMissingAuthEvents: %w", err) } @@ -113,14 +131,14 @@ func (r *Inputer) processRoomEvent( var rejectionErr error if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true - logrus.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) + logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) } // Accumulate the auth event NIDs. authEventIDs := event.AuthEventIDs() authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { - authEventNIDs = append(authEventNIDs, knownAuthEvents[authEventID].EventNID) + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } var softfail bool @@ -129,11 +147,7 @@ func (r *Inputer) processRoomEvent( // current room state. softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - }).WithError(err).Info("Error authing soft-failed event") + logger.WithError(err).Info("Error authing soft-failed event") } } @@ -156,12 +170,7 @@ func (r *Inputer) processRoomEvent( // doesn't have any associated state to store and we don't need to // notify anyone about it. if input.Kind == api.KindOutlier { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - "sender": event.Sender(), - }).Debug("Stored outlier") + logger.Debug("Stored outlier") return event.EventID(), nil } @@ -173,6 +182,27 @@ func (r *Inputer) processRoomEvent( return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) } + if input.Origin == "" { + return "", fmt.Errorf("expected an origin") + } + + if len(missingRes.MissingPrevEventIDs) > 0 { + missingState := missingStateReq{ + origin: input.Origin, + inputer: r, + queryer: r.Queryer, + db: r.DB, + federation: r.FSAPI, + roomsMu: internal.NewMutexByRoom(), + servers: []gomatrixserverlib.ServerName{input.Origin}, + hadEvents: map[string]bool{}, + haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, + } + if err = missingState.processEventWithMissingState(ctx, input.Event.Unwrap(), roomInfo.RoomVersion); err != nil { + return "", fmt.Errorf("r.checkForMissingPrevEvents: %w", err) + } + } + if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. @@ -184,13 +214,7 @@ func (r *Inputer) processRoomEvent( // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - "soft_fail": softfail, - "sender": event.Sender(), - }).Debug("Stored rejected event") + logger.WithField("soft_fail", softfail).Debug("Stored rejected event") return event.EventID(), rejectionErr } @@ -246,9 +270,10 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) checkForMissingAuthEvents( ctx context.Context, + logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, - known map[string]types.Event, + known map[string]*types.Event, ) error { authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { @@ -263,7 +288,8 @@ func (r *Inputer) checkForMissingAuthEvents( } for _, event := range authEvents { if event.Event != nil { - known[event.EventID()] = event + ev := event // don't take the address of the iterated value + known[event.EventID()] = &ev if err = auth.AddEvent(event.Event); err != nil { return fmt.Errorf("auth.AddEvent: %w", err) } @@ -273,7 +299,7 @@ func (r *Inputer) checkForMissingAuthEvents( } if len(unknown) > 0 { - logrus.Printf("XXX: There are %d missing auth events", len(unknown)) + logger.Printf("XXX: There are %d missing auth events", len(unknown)) serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -283,22 +309,22 @@ func (r *Inputer) checkForMissingAuthEvents( return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } - logrus.Printf("XXX: Asking servers %+v", serverRes.ServerNames) + logger.Printf("XXX: Asking servers %+v", serverRes.ServerNames) var res gomatrixserverlib.RespEventAuth var found bool for _, serverName := range serverRes.ServerNames { res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomID(), event.EventID()) if err != nil { - logrus.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) + logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue } - logrus.Printf("XXX: Server %q provided us with %d auth events", serverName, len(res.AuthEvents)) + logger.Printf("XXX: Server %q provided us with %d auth events", serverName, len(res.AuthEvents)) found = true break } if !found { - logrus.Printf("XXX: None of the %d servers provided us with auth events", len(serverRes.ServerNames)) + logger.Printf("XXX: None of the %d servers provided us with auth events", len(serverRes.ServerNames)) return fmt.Errorf("no servers provided event auth") } @@ -332,7 +358,7 @@ func (r *Inputer) checkForMissingAuthEvents( } // Let's take a note of the fact that we now know about this event. - known[event.EventID()] = types.Event{} + known[event.EventID()] = nil if err := auth.AddEvent(event); err != nil { return fmt.Errorf("auth.AddEvent: %w", err) } @@ -341,7 +367,7 @@ func (r *Inputer) checkForMissingAuthEvents( isRejected := false if err := gomatrixserverlib.Allowed(event, auth); err != nil { isRejected = true - logrus.WithError(err).Warnf("Auth event %s rejected", event.EventID()) + logger.WithError(err).Warnf("Auth event %s rejected", event.EventID()) } // Finally, store the event in the database. @@ -351,7 +377,7 @@ func (r *Inputer) checkForMissingAuthEvents( } // Now we know about this event, too. - known[event.EventID()] = types.Event{ + known[event.EventID()] = &types.Event{ EventNID: eventNID, Event: event, } @@ -361,6 +387,228 @@ func (r *Inputer) checkForMissingAuthEvents( return nil } +/* +func (r *Inputer) checkForMissingPrevEvents( + ctx context.Context, + logger *logrus.Entry, + event *gomatrixserverlib.HeaderedEvent, + roomInfo *types.RoomInfo, + known map[string]*types.Event, +) error { + prevStates := map[string]*types.StateAtEvent{} + prevEventIDs := event.PrevEventIDs() + if len(prevEventIDs) == 0 && event.Type() != gomatrixserverlib.MRoomCreate { + return fmt.Errorf("expected to find some prev events for event type %q", event.Type()) + } + + for _, eventID := range prevEventIDs { + state, err := r.DB.StateAtEventIDs(ctx, []string{eventID}) + if err != nil { + if _, ok := err.(types.MissingEventError); ok { + continue + } + return fmt.Errorf("r.DB.StateAtEventIDs: %w", err) + } + if len(state) == 1 { + prevStates[eventID] = &state[0] + continue + } + } + + // If we know all of the states of the previous events then there is nothing more to + // do here, as the state across them will be resolved later. + if len(prevStates) == len(prevEventIDs) { + return nil + } + if r.FSAPI == nil { + return fmt.Errorf("cannot satisfy missing events without federation") + } + + // Ask the federation API which servers we should ask. In theory the roomserver + // doesn't need the help of the federation API to do this because we already know + // all of the membership states, it's just that the federation API tracks this in + // a table for this purpose. TODO: Work out what makes most sense here. + serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: event.RoomID(), + } + serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} + if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + } + + // Attempt to fill in the gap using /get_missing_events + // This will either: + // - fill in the gap completely then process event `e` returning no backwards extremity + // - fail to fill in the gap and tell us to terminate the transaction err=not nil + // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction + newEvents, err := r.getMissingEvents(ctx, logger, event, roomInfo, serverRes.ServerNames, known) + if err != nil { + return err + } + if len(newEvents) == 0 { + return fmt.Errorf("/get_missing_events returned no new events") + } + + return nil +} + +func (r *Inputer) getMissingEvents( + ctx context.Context, + logger *logrus.Entry, + event *gomatrixserverlib.HeaderedEvent, + roomInfo *types.RoomInfo, + servers []gomatrixserverlib.ServerName, + known map[string]*types.Event, +) (newEvents []*gomatrixserverlib.Event, err error) { + logger.Printf("XXX: get_missing_events called") + needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) + + // Ask the roomserver for our current forward extremities. These will form + // the "earliest" part of the `/get_missing_events` request. + req := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + StateToFetch: needed.Tuples(), + } + res := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, req, res); err != nil { + logger.WithError(err).Warn("Failed to query latest events") + return nil, err + } + + // Accumulate the event IDs of our forward extremities for use in the request. + latestEvents := make([]string, len(res.LatestEvents)) + for i := range res.LatestEvents { + latestEvents[i] = res.LatestEvents[i].EventID + } + + var missingResp *gomatrixserverlib.RespMissingEvents + for _, server := range servers { + logger.Printf("XXX: Calling /get_missing_events via %q", server) + var m gomatrixserverlib.RespMissingEvents + if m, err = r.FSAPI.LookupMissingEvents(ctx, server, event.RoomID(), gomatrixserverlib.MissingEvents{ + Limit: 20, + EarliestEvents: latestEvents, + LatestEvents: []string{event.EventID()}, + }, event.RoomVersion); err == nil { + missingResp = &m + break + } else if errors.Is(err, context.DeadlineExceeded) { + break + } + } + + if missingResp == nil { + return nil, fmt.Errorf("/get_missing_events failed via all candidate servers") + } + if len(missingResp.Events) == 0 { + return nil, fmt.Errorf("/get_missing_events returned no events") + } + + // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. + // There's 2 scenarios to consider: + // - Case A: We got pushed an event and are now fetching missing prev_events. (isInboundTxn=true) + // - Case B: We are fetching missing prev_events already and now fetching some more (isInboundTxn=false) + // In Case B, we know for sure that the event we are currently processing will not become the new forward extremity for the room, + // as it was called in response to an inbound txn which had it as a prev_event. + // In Case A, the event is a forward extremity, and could eventually become the _only_ forward extremity in the room. This is bad + // because it means we would trust the state at that event to be the state for the entire room, and allows rooms to be hijacked. + // https://github.com/matrix-org/synapse/pull/3456 + // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 + // For now, we do not allow Case B, so reject the event. + logger.Printf("XXX: get_missing_events returned %d events", len(missingResp.Events)) + + newEvents = gomatrixserverlib.ReverseTopologicalOrdering( + missingResp.Events, + gomatrixserverlib.TopologicalOrderByPrevEvents, + ) + for _, pe := range event.PrevEventIDs() { + hasPrevEvent := false + for _, ev := range newEvents { + if ev.EventID() == pe { + hasPrevEvent = true + break + } + } + if !hasPrevEvent { + logger.Errorf("Prev event %q is still missing after /get_missing_events", pe) + } + } + + backwardExtremity := newEvents[0] + fastForwardEvents := newEvents[1:] + + // Do we know about the state of the backward extremity already? + if _, err := r.DB.StateAtEventIDs(ctx, []string{backwardExtremity.EventID()}); err == nil { + // Yes, we do, so we don't need to store that event. + } else { + // No, we don't, so let's go find it. + // r.FSAPI.LookupStateIDs() + } + + for _, ev := range fastForwardEvents { + if _, err := r.processRoomEvent(ctx, &api.InputRoomEvent{ + Kind: api.KindOld, + Event: ev.Headered(event.RoomVersion), + AuthEventIDs: ev.AuthEventIDs(), + }); err != nil { + return nil, fmt.Errorf("r.processRoomEvent (prev event): %w", err) + } + } + + return newEvents, nil +} + +func (r *Inputer) lookupStateBeforeEvent( + ctx context.Context, + logger *logrus.Entry, + event *gomatrixserverlib.HeaderedEvent, + roomInfo *types.RoomInfo, + servers []gomatrixserverlib.ServerName, +) error { + knownPrevStates := map[string]types.StateAtEvent{} + unknownPrevStates := map[string]struct{}{} + neededStateEvents := map[string]struct{}{} + + for _, prevEventID := range event.PrevEventIDs() { + if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err == nil && len(state) == 1 { + knownPrevStates[prevEventID] = state[0] + } else { + unknownPrevStates[prevEventID] = struct{}{} + } + } + + for prevEventID := range unknownPrevStates { + stateIDs, err := r.FSAPI.LookupStateIDs(ctx, "TODO: SERVER", event.RoomID(), prevEventID) + if err != nil { + return fmt.Errorf("r.FSAPI.LookupStateIDs: %w", err) + } + events, err := r.DB.EventsFromIDs(ctx, stateIDs.StateEventIDs) + if err != nil { + return fmt.Errorf("r.DB.EventsFromIDs: %w", err) + } + for i, eventID := range stateIDs.StateEventIDs { + if events[i].Event == nil || events[i].EventNID == 0 { + neededStateEvents[eventID] = struct{}{} + } + } + + if len(neededStateEvents) > (len(stateIDs.StateEventIDs) / 2) { + // More than 50% of the state events are missing, so let's just + // call `/state` instead of fetching the events individually. + state, err := r.FSAPI.LookupState(ctx, "", event.RoomID(), prevEventID, roomInfo.RoomVersion) + if err != nil { + return fmt.Errorf("r.FSAPI.LookupState: %w", err) + } + knownPrevStates[prevEventID] = types.StateAtEvent{ + StateEntry: types.StateEntry{}, + } + } + } + + return nil +} +*/ + func (r *Inputer) calculateAndSetState( ctx context.Context, input *api.InputRoomEvent, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go new file mode 100644 index 000000000..8d0632150 --- /dev/null +++ b/roomserver/internal/input/input_missing.go @@ -0,0 +1,708 @@ +package input + +import ( + "context" + "errors" + "fmt" + "sync" + + fedapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type missingStateReq struct { + origin gomatrixserverlib.ServerName + db storage.Database + inputer *Inputer + queryer *query.Queryer + keys gomatrixserverlib.JSONVerifier + federation fedapi.FederationInternalAPI + roomsMu *internal.MutexByRoom + servers []gomatrixserverlib.ServerName + hadEvents map[string]bool + hadEventsMutex sync.Mutex + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEventsMutex sync.Mutex +} + +func (t *missingStateReq) processEventWithMissingState( + ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, +) error { + // We are missing the previous events for this events. + // This means that there is a gap in our view of the history of the + // room. There two ways that we can handle such a gap: + // 1) We can fill in the gap using /get_missing_events + // 2) We can leave the gap and request the state of the room at + // this event from the remote server using either /state_ids + // or /state. + // Synapse will attempt to do 1 and if that fails or if the gap is + // too large then it will attempt 2. + // Synapse will use /state_ids if possible since usually the state + // is largely unchanged and it is more efficient to fetch a list of + // event ids and then use /event to fetch the individual events. + // However not all version of synapse support /state_ids so you may + // need to fallback to /state. + + // Attempt to fill in the gap using /get_missing_events + // This will either: + // - fill in the gap completely then process event `e` returning no backwards extremity + // - fail to fill in the gap and tell us to terminate the transaction err=not nil + // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction + newEvents, err := t.getMissingEvents(ctx, e, roomVersion) + if err != nil { + return err + } + if len(newEvents) == 0 { + return nil + } + + backwardsExtremity := newEvents[0] + newEvents = newEvents[1:] + + type respState struct { + // A snapshot is considered trustworthy if it came from our own roomserver. + // That's because the state will have been through state resolution once + // already in QueryStateAfterEvent. + trustworthy bool + *gomatrixserverlib.RespState + } + + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. + var states []*respState + for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + // Look up what the state is after the backward extremity. This will either + // come from the roomserver, if we know all the required events, or it will + // come from a remote server via /state_ids if not. + prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) + if lerr != nil { + util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + return lerr + } + // Append the state onto the collected state. We'll run this through the + // state resolution next. + states = append(states, &respState{trustworthy, prevState}) + } + + // Now that we have collected all of the state from the prev_events, we'll + // run the state through the appropriate state resolution algorithm for the + // room if needed. This does a couple of things: + // 1. Ensures that the state is deduplicated fully for each state-key tuple + // 2. Ensures that we pick the latest events from both sets, in the case that + // one of the prev_events is quite a bit older than the others + resolvedState := &gomatrixserverlib.RespState{} + switch len(states) { + case 0: + extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") + if !extremityIsCreate { + // There are no previous states and this isn't the beginning of the + // room - this is an error condition! + util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events") + return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) + } + case 1: + // There's only one previous state - if it's trustworthy (came from a + // local state snapshot which will already have been through state res), + // use it as-is. There's no point in resolving it again. + if states[0].trustworthy { + resolvedState = states[0].RespState + break + } + // Otherwise, if it isn't trustworthy (came from federation), run it through + // state resolution anyway for safety, in case there are duplicates. + fallthrough + default: + respStates := make([]*gomatrixserverlib.RespState, len(states)) + for i := range states { + respStates[i] = states[i].RespState + } + // There's more than one previous state - run them all through state res + t.roomsMu.Lock(e.RoomID()) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) + t.roomsMu.Unlock(e.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) + return err + } + } + + // First of all, send the backward extremity into the roomserver with the + // newly resolved state. This marks the "oldest" point in the backfill and + // sets the baseline state for any new events after this. We'll make a + // copy of the hadEvents map so that it can be taken downstream without + // worrying about concurrent map reads/writes, since t.hadEvents is meant + // to be protected by a mutex. + hadEvents := map[string]bool{} + t.hadEventsMutex.Lock() + for k, v := range t.hadEvents { + hadEvents[k] = v + } + t.hadEventsMutex.Unlock() + + err = api.SendEventWithState( + context.Background(), + t.inputer, + api.KindOld, + resolvedState, + backwardsExtremity.Headered(roomVersion), + t.origin, + hadEvents, + ) + if err != nil { + return fmt.Errorf("api.SendEventWithState: %w", err) + } + + // Then send all of the newer backfilled events, of which will all be newer + // than the backward extremity, into the roomserver without state. This way + // they will automatically fast-forward based on the room state at the + // extremity in the last step. + headeredNewEvents := make([]*gomatrixserverlib.HeaderedEvent, len(newEvents)) + for i, newEvent := range newEvents { + headeredNewEvents[i] = newEvent.Headered(roomVersion) + } + if err = api.SendEvents( + context.Background(), + t.inputer, + api.KindOld, + append(headeredNewEvents, e.Headered(roomVersion)), + t.origin, + api.DoNotSendToOtherServers, + nil, + ); err != nil { + return fmt.Errorf("api.SendEvents: %w", err) + } + + return nil +} + +// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) +// added into the mix. +func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { + // try doing all this locally before we resort to querying federation + respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) + if respState != nil { + return respState, true, nil + } + + respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) + if err != nil { + return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) + } + + // fetch the event we're missing and add it to the pile + h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) + switch err.(type) { + case verifySigError: + return respState, false, nil + case nil: + // do nothing + default: + return nil, false, fmt.Errorf("t.lookupEvent: %w", err) + } + h = t.cacheAndReturn(h) + if h.StateKey() != nil { + addedToState := false + for i := range respState.StateEvents { + se := respState.StateEvents[i] + if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { + respState.StateEvents[i] = h.Unwrap() + addedToState = true + break + } + } + if !addedToState { + respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + } + } + + return respState, false, nil +} + +func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() + if cached, exists := t.haveEvents[ev.EventID()]; exists { + return cached + } + t.haveEvents[ev.EventID()] = ev + return ev +} + +func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { + var res api.QueryStateAfterEventsResponse + err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ + RoomID: roomID, + PrevEventIDs: []string{eventID}, + }, &res) + if err != nil || !res.PrevEventsExist { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) + return nil + } + stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) + for i, ev := range res.StateEvents { + // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this + // processEvent request, which is better for memory. + stateEvents[i] = t.cacheAndReturn(ev) + t.hadEvent(ev.EventID(), true) + } + // we should never access res.StateEvents again so we delete it here to make GC faster + res.StateEvents = nil + + var authEvents []*gomatrixserverlib.Event + missingAuthEvents := map[string]bool{} + for _, ev := range stateEvents { + t.haveEventsMutex.Lock() + for _, ae := range ev.AuthEventIDs() { + if aev, ok := t.haveEvents[ae]; ok { + authEvents = append(authEvents, aev.Unwrap()) + } else { + missingAuthEvents[ae] = true + } + } + t.haveEventsMutex.Unlock() + } + // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't + // have stored the event. + if len(missingAuthEvents) > 0 { + var missingEventList []string + for evID := range missingAuthEvents { + missingEventList = append(missingEventList, evID) + } + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") + var queryRes api.QueryEventsByIDResponse + if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + return nil + } + for i, ev := range queryRes.Events { + authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + t.hadEvent(ev.EventID(), true) + } + queryRes.Events = nil + } + + return &gomatrixserverlib.RespState{ + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + AuthEvents: authEvents, + } +} + +// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what +// the server supports. +func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( + *gomatrixserverlib.RespState, error) { + + // Attempt to fetch the missing state using /state_ids and /events + return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) +} + +func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + var authEventList []*gomatrixserverlib.Event + var stateEventList []*gomatrixserverlib.Event + for _, state := range states { + authEventList = append(authEventList, state.AuthEvents...) + stateEventList = append(stateEventList, state.StateEvents...) + } + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + if err != nil { + return nil, err + } + // apply the current event +retryAllowedState: + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + switch missing := err.(type) { + case gomatrixserverlib.MissingAuthEventError: + h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) + switch err2.(type) { + case verifySigError: + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil + case nil: + // do nothing + default: + return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) + } + util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) + resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + goto retryAllowedState + default: + } + return nil, err + } + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil +} + +func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) { + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) + // query latest events (our trusted forward extremities) + req := api.QueryLatestEventsAndStateRequest{ + RoomID: e.RoomID(), + StateToFetch: needed.Tuples(), + } + var res api.QueryLatestEventsAndStateResponse + if err = t.queryer.QueryLatestEventsAndState(ctx, &req, &res); err != nil { + logger.WithError(err).Warn("Failed to query latest events") + return nil, err + } + latestEvents := make([]string, len(res.LatestEvents)) + for i, ev := range res.LatestEvents { + latestEvents[i] = res.LatestEvents[i].EventID + t.hadEvent(ev.EventID, true) + } + + var missingResp *gomatrixserverlib.RespMissingEvents + for _, server := range t.servers { + var m gomatrixserverlib.RespMissingEvents + if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + Limit: 20, + // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. + EarliestEvents: latestEvents, + // The event IDs to retrieve the previous events for. + LatestEvents: []string{e.EventID()}, + }, roomVersion); err == nil { + missingResp = &m + break + } else { + logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) + if errors.Is(err, context.DeadlineExceeded) { + break + } + } + } + + if missingResp == nil { + logger.WithError(err).Errorf( + "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.origin, len(t.servers), + ) + return nil, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + + // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. + // There's 2 scenarios to consider: + // - Case A: We got pushed an event and are now fetching missing prev_events. (isInboundTxn=true) + // - Case B: We are fetching missing prev_events already and now fetching some more (isInboundTxn=false) + // In Case B, we know for sure that the event we are currently processing will not become the new forward extremity for the room, + // as it was called in response to an inbound txn which had it as a prev_event. + // In Case A, the event is a forward extremity, and could eventually become the _only_ forward extremity in the room. This is bad + // because it means we would trust the state at that event to be the state for the entire room, and allows rooms to be hijacked. + // https://github.com/matrix-org/synapse/pull/3456 + // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 + // For now, we do not allow Case B, so reject the event. + logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + + // Make sure events from the missingResp are using the cache - missing events + // will be added and duplicates will be removed. + for i, ev := range missingResp.Events { + missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + + // topologically sort and sanity check that we are making forward progress + newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + shouldHaveSomeEventIDs := e.PrevEventIDs() + hasPrevEvent := false +Event: + for _, pe := range shouldHaveSomeEventIDs { + for _, ev := range newEvents { + if ev.EventID() == pe { + hasPrevEvent = true + break Event + } + } + } + if !hasPrevEvent { + err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.origin, shouldHaveSomeEventIDs) + logger.WithError(err).Errorf( + "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.origin, + ) + return nil, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + + return newEvents, nil +} + +func (t *missingStateReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + respState *gomatrixserverlib.RespState, err error) { + state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + if err != nil { + return nil, err + } + // Check that the returned state is valid. + if err := state.Check(ctx, t.keys, nil); err != nil { + return nil, err + } + // Cache the results of this state lookup and deduplicate anything we already + // have in the cache, freeing up memory. + for i, ev := range state.AuthEvents { + state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + for i, ev := range state.StateEvents { + state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + return &state, nil +} + +func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + *gomatrixserverlib.RespState, error) { + util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) + // fetch the state event IDs at the time of the event + stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) + if err != nil { + return nil, err + } + // work out which auth/state IDs are missing + wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) + missing := make(map[string]bool) + var missingEventList []string + t.haveEventsMutex.Lock() + for _, sid := range wantIDs { + if _, ok := t.haveEvents[sid]; !ok { + if !missing[sid] { + missing[sid] = true + missingEventList = append(missingEventList, sid) + } + } + } + t.haveEventsMutex.Unlock() + + // fetch as many as we can from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + var queryRes api.QueryEventsByIDResponse + if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + return nil, err + } + for i, ev := range queryRes.Events { + queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + t.hadEvent(ev.EventID(), true) + evID := queryRes.Events[i].EventID() + if missing[evID] { + delete(missing, evID) + } + } + queryRes.Events = nil // allow it to be GCed + + concurrentRequests := 8 + missingCount := len(missing) + util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) + + // If over 50% of the auth/state events from /state_ids are missing + // then we'll just call /state instead, otherwise we'll just end up + // hammering the remote side with /event requests unnecessarily. + if missingCount > concurrentRequests && missingCount > len(wantIDs)/2 { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + }).Info("Fetching all state at event") + return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) + } + + if missingCount > 0 { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + "concurrent_requests": concurrentRequests, + }).Info("Fetching missing state at event") + + // Create a queue containing all of the missing event IDs that we want + // to retrieve. + pending := make(chan string, missingCount) + for missingEventID := range missing { + pending <- missingEventID + } + close(pending) + + // Define how many workers we should start to do this. + if missingCount < concurrentRequests { + concurrentRequests = missingCount + } + + // Create the wait group. + var fetchgroup sync.WaitGroup + fetchgroup.Add(concurrentRequests) + + // This is the only place where we'll write to t.haveEvents from + // multiple goroutines, and everywhere else is blocked on this + // synchronous function anyway. + var haveEventsMutex sync.Mutex + + // Define what we'll do in order to fetch the missing event ID. + fetch := func(missingEventID string) { + var h *gomatrixserverlib.HeaderedEvent + h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) + switch err.(type) { + case verifySigError: + return + case nil: + break + default: + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": missingEventID, + "room_id": roomID, + }).Info("Failed to fetch missing event") + return + } + haveEventsMutex.Lock() + t.cacheAndReturn(h) + haveEventsMutex.Unlock() + } + + // Create the worker. + worker := func(ch <-chan string) { + defer fetchgroup.Done() + for missingEventID := range ch { + fetch(missingEventID) + } + } + + // Start the workers. + for i := 0; i < concurrentRequests; i++ { + go worker(pending) + } + + // Wait for the workers to finish. + fetchgroup.Wait() + } + + resp, err := t.createRespStateFromStateIDs(stateIDs) + return resp, err +} + +func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( + *gomatrixserverlib.RespState, error) { // nolint:unparam + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() + + // create a RespState response using the response to /state_ids as a guide + respState := gomatrixserverlib.RespState{} + + for i := range stateIDs.StateEventIDs { + ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] + if !ok { + logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + continue + } + respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) + } + for i := range stateIDs.AuthEventIDs { + ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] + if !ok { + logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + continue + } + respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) + } + // We purposefully do not do auth checks on the returned events, as they will still + // be processed in the exact same way, just as a 'rejected' event + // TODO: Add a field to HeaderedEvent to indicate if the event is rejected. + return &respState, nil +} + +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { + if localFirst { + // fetch from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: []string{missingEventID}, + } + var queryRes api.QueryEventsByIDResponse + if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) + } else if len(queryRes.Events) == 1 { + return queryRes.Events[0], nil + } + } + var event *gomatrixserverlib.Event + found := false + for _, serverName := range t.servers { + txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) + if err != nil || len(txn.PDUs) == 0 { + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") + if errors.Is(err, context.DeadlineExceeded) { + break + } + continue + } + event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") + continue + } + found = true + break + } + if !found { + util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) + return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) + } + if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + return nil, verifySigError{event.EventID(), err} + } + return t.cacheAndReturn(event.Headered(roomVersion)), nil +} + +func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { + authUsingState := gomatrixserverlib.NewAuthEvents(nil) + for i := range stateEvents { + err := authUsingState.AddEvent(stateEvents[i]) + if err != nil { + return err + } + } + return gomatrixserverlib.Allowed(e, &authUsingState) +} + +func (t *missingStateReq) hadEvent(eventID string, had bool) { + t.hadEventsMutex.Lock() + defer t.hadEventsMutex.Unlock() + t.hadEvents[eventID] = had +} + +type roomNotFoundError struct { + roomID string +} +type verifySigError struct { + eventID string + err error +} +type missingPrevEventsError struct { + eventID string + err error +} + +func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) } +func (e verifySigError) Error() string { + return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) +} +func (e missingPrevEventsError) Error() string { + return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) +} diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index ca0654685..3083d6537 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -172,6 +172,7 @@ func (r *Inviter) PerformInvite( { Kind: api.KindNew, Event: event, + Origin: event.Origin(), AuthEventIDs: event.AuthEventIDs(), SendAsServer: req.SendAsServer, }, diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 40e8e92d1..2e887db18 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -189,7 +189,7 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js t.Helper() rsAPI, dp := mustCreateRoomserverAPI(t) hevents := mustLoadRawEvents(t, ver, events) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, hevents, testOrigin, nil); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, hevents, testOrigin, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents @@ -335,7 +335,7 @@ func TestOutputRewritesState(t *testing.T) { deleteDatabase() rsAPI, producer := mustCreateRoomserverAPI(t) defer deleteDatabase() - err := api.SendEvents(context.Background(), rsAPI, api.KindNew, originalEvents, testOrigin, nil) + err := api.SendEvents(context.Background(), rsAPI, api.KindNew, originalEvents, testOrigin, testOrigin, nil) if err != nil { t.Fatalf("failed to send original events: %s", err) }