Reject instead of soft-fail, don't copy roominfo so much

This commit is contained in:
Neil Alexander 2022-01-07 10:50:19 +00:00
parent eff348bb69
commit af34b4abe3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
9 changed files with 22 additions and 23 deletions

View file

@ -56,7 +56,7 @@ func CheckForSoftFail(
// Then get the state entries for the current state snapshot. // Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now. // We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, *roomInfo) roomState := state.NewStateResolution(db, roomInfo)
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil { if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)

View file

@ -179,7 +179,7 @@ func GetMembershipsAtState(
return events, nil return events, nil
} }
func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info)
// Lookup the event NID // Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
@ -223,7 +223,7 @@ func LoadStateEvents(
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) { ) (bool, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
@ -279,7 +279,7 @@ func CheckServerAllowedToSeeEvent(
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
func ScanEventTree( func ScanEventTree(
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
var resultNIDs []types.EventNID var resultNIDs []types.EventNID
@ -387,7 +387,7 @@ func QueryLatestEventsAndState(
return nil return nil
} }
roomState := state.NewStateResolution(db, *roomInfo) roomState := state.NewStateResolution(db, roomInfo)
response.RoomExists = true response.RoomExists = true
response.RoomVersion = roomInfo.RoomVersion response.RoomVersion = roomInfo.RoomVersion

View file

@ -185,8 +185,7 @@ func (r *Inputer) processRoomEvent(
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
} }
if err = missingState.processEventWithMissingState(ctx, input.Event.Unwrap(), input.Event.RoomVersion); err != nil { if err = missingState.processEventWithMissingState(ctx, input.Event.Unwrap(), input.Event.RoomVersion); err != nil {
//return fmt.Errorf("r.checkForMissingPrevEvents: %w", err) isRejected = true
softfail = true
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
} }
} }
@ -225,7 +224,7 @@ func (r *Inputer) processRoomEvent(
if stateAtEvent.BeforeStateSnapshotNID == 0 { if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected) err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil && input.Kind != api.KindOld { if err != nil && input.Kind != api.KindOld {
return fmt.Errorf("r.calculateAndSetState: %w", err) return fmt.Errorf("r.calculateAndSetState: %w", err)
} }
@ -233,7 +232,7 @@ func (r *Inputer) processRoomEvent(
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
if isRejected || softfail { if isRejected || softfail {
logger.WithField("soft_fail", softfail).WithField("reason", rejectionErr).Debug("Stored rejected event") logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event")
return rejectionErr return rejectionErr
} }
@ -395,7 +394,7 @@ func (r *Inputer) checkForMissingAuthEvents(
func (r *Inputer) calculateAndSetState( func (r *Inputer) calculateAndSetState(
ctx context.Context, ctx context.Context,
input *api.InputRoomEvent, input *api.InputRoomEvent,
roomInfo types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
isRejected bool, isRejected bool,

View file

@ -199,7 +199,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
var err error var err error
roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
// Work out if the state at the extremities has actually changed // Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the // or not. If they haven't then we won't bother doing all of the

View file

@ -77,7 +77,7 @@ func (r *Backfiller) PerformBackfill(
} }
// Scan the event tree for events to send back. // Scan the event tree for events to send back.
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -418,7 +418,7 @@ FindSuccessor:
return nil return nil
} }
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID])
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil return nil

View file

@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion)
// XXX: do we actually need to do a state resolution here? // XXX: do we actually need to do a state resolution here?
roomState := state.NewStateResolution(r.DB, *info) roomState := state.NewStateResolution(r.DB, info)
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
stateEntries, err = roomState.LoadStateAtSnapshot( stateEntries, err = roomState.LoadStateAtSnapshot(

View file

@ -232,7 +232,7 @@ func buildInviteStrippedState(
StateKey: "", StateKey: "",
}) })
} }
roomState := state.NewStateResolution(db, *info) roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, info.StateSnapshotNID, stateWanted, ctx, info.StateSnapshotNID, stateWanted,
) )

View file

@ -63,7 +63,7 @@ func (r *Queryer) QueryStateAfterEvents(
return nil return nil
} }
roomState := state.NewStateResolution(r.DB, *info) roomState := state.NewStateResolution(r.DB, info)
response.RoomExists = true response.RoomExists = true
response.RoomVersion = info.RoomVersion response.RoomVersion = info.RoomVersion
@ -294,7 +294,7 @@ func (r *Queryer) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs) events, err = r.DB.Events(ctx, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
@ -377,7 +377,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
} }
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
) )
return return
} }
@ -416,7 +416,7 @@ func (r *Queryer) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
} }
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -458,7 +458,7 @@ func (r *Queryer) QueryStateAndAuthChain(
response.RoomVersion = info.RoomVersion response.RoomVersion = info.RoomVersion
var stateEvents []*gomatrixserverlib.Event var stateEvents []*gomatrixserverlib.Event
stateEvents, err = r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) stateEvents, err = r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -497,7 +497,7 @@ func (r *Queryer) QueryStateAndAuthChain(
return err return err
} }
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(r.DB, roomInfo) roomState := state.NewStateResolution(r.DB, roomInfo)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil { if err != nil {

View file

@ -32,11 +32,11 @@ import (
type StateResolution struct { type StateResolution struct {
db storage.Database db storage.Database
roomInfo types.RoomInfo roomInfo *types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event events map[types.EventNID]*gomatrixserverlib.Event
} }
func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
return StateResolution{ return StateResolution{
db: db, db: db,
roomInfo: roomInfo, roomInfo: roomInfo,