Add eventsize checks similar to Synapse

This commit is contained in:
Till Faelligen 2023-07-06 15:43:43 +02:00
parent fea946d914
commit f30432678b
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
2 changed files with 74 additions and 12 deletions

View file

@ -250,6 +250,21 @@ func (r *Inputer) processRoomEvent(
// really do anything with the event other than reject it at this point. // really do anything with the event other than reject it at this point.
isRejected = true isRejected = true
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
if e.Persistable && stateSnapshot != nil {
// We retrieved some state and we ended up having to call /state_ids for
// the new event in question (probably because closing the gap by using
// /get_missing_events didn't do what we hoped) so we'll instead overwrite
// the state snapshot with the newly resolved state.
missingPrev = false
input.HasState = true
input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents))
for _, se := range stateSnapshot.StateEvents {
input.StateEventIDs = append(input.StateEventIDs, se.EventID())
}
}
}
} else if stateSnapshot != nil { } else if stateSnapshot != nil {
// We retrieved some state and we ended up having to call /state_ids for // We retrieved some state and we ended up having to call /state_ids for
// the new event in question (probably because closing the gap by using // the new event in question (probably because closing the gap by using

View file

@ -259,12 +259,20 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
var states []*respState var states []*respState
var validationError error
for _, prevEventID := range e.PrevEventIDs() { for _, prevEventID := range e.PrevEventIDs() {
// Look up what the state is after the backward extremity. This will either // Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will // come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not. // come from a remote server via /state_ids if not.
prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID)
if err != nil { switch err2 := err.(type) {
case gomatrixserverlib.EventValidationError:
if !err2.Persistable {
return nil, err2
}
validationError = err2
case nil:
default:
return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err) return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err)
} }
// Append the state onto the collected state. We'll run this through the // Append the state onto the collected state. We'll run this through the
@ -311,12 +319,19 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
t.roomsMu.Lock(e.RoomID()) t.roomsMu.Lock(e.RoomID())
resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e)
t.roomsMu.Unlock(e.RoomID()) t.roomsMu.Unlock(e.RoomID())
if err != nil { switch err2 := err.(type) {
case gomatrixserverlib.EventValidationError:
if !err2.Persistable {
return nil, err2
}
validationError = err2
case nil:
default:
return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err) return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err)
} }
} }
return resolvedState, nil return resolvedState, validationError
} }
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
@ -339,8 +354,15 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
} }
// fetch the event we're missing and add it to the pile // fetch the event we're missing and add it to the pile
var validationError error
h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false)
switch err.(type) { switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
logrus.WithContext(ctx).WithError(err).Errorf("Failed to look up event %s", eventID)
return nil, false, e
}
validationError = e
case verifySigError: case verifySigError:
return respState, false, nil return respState, false, nil
case nil: case nil:
@ -365,7 +387,7 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
} }
} }
return respState, false, nil return respState, false, validationError
} }
func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU { func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU {
@ -481,6 +503,7 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
return nil, err return nil, err
} }
// apply the current event // apply the current event
var validationError error
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
@ -488,7 +511,12 @@ retryAllowedState:
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
switch err2.(type) { switch e := err2.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
return nil, e
}
validationError = e
case verifySigError: case verifySigError:
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
@ -509,7 +537,7 @@ retryAllowedState:
return &parsedRespState{ return &parsedRespState{
AuthEvents: authEventList, AuthEvents: authEventList,
StateEvents: resolvedStateEvents, StateEvents: resolvedStateEvents,
}, nil }, validationError
} }
// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject,
@ -779,7 +807,11 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
h, herr := t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) h, herr := t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
switch herr.(type) { switch e := herr.(type) {
case gomatrixserverlib.EventValidationError:
if !e.Persistable {
return
}
case verifySigError: case verifySigError:
return return
case nil: case nil:
@ -869,6 +901,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
} }
var event gomatrixserverlib.PDU var event gomatrixserverlib.PDU
found := false found := false
var validationError error
serverLoop:
for _, serverName := range t.servers { for _, serverName := range t.servers {
reqctx, cancel := context.WithTimeout(ctx, time.Second*30) reqctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel() defer cancel()
@ -886,12 +920,25 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
continue continue
} }
event, err = verImpl.NewEventFromUntrustedJSON(txn.PDUs[0]) event, err = verImpl.NewEventFromUntrustedJSON(txn.PDUs[0])
if err != nil { switch e := err.(type) {
case gomatrixserverlib.EventValidationError:
// If the event is persistable, e.g. failed validation for exceeding
// byte sizes, we can "accept" the event.
if e.Persistable {
validationError = e
found = true
break serverLoop
}
// If we can't persist the event, we probably can't do so with results
// from other servers, so also break the loop.
break serverLoop
case nil:
found = true
break serverLoop
default:
t.log.WithError(err).WithField("missing_event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") t.log.WithError(err).WithField("missing_event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event")
continue continue
} }
found = true
break
} }
if !found { if !found {
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
@ -903,7 +950,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
return t.cacheAndReturn(event), nil return t.cacheAndReturn(event), validationError
} }
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {