diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 1f4215e74..ddda8081c 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -56,7 +56,7 @@ func CheckForSoftFail( // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, roomInfo) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a389cc898..78a875c76 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -179,7 +179,7 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) @@ -223,7 +223,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) @@ -279,7 +279,7 @@ func CheckServerAllowedToSeeEvent( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( - ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, + ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID @@ -387,7 +387,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, roomInfo) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 6ca549a3c..db631b632 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -185,8 +185,7 @@ func (r *Inputer) processRoomEvent( haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } if err = missingState.processEventWithMissingState(ctx, input.Event.Unwrap(), input.Event.RoomVersion); err != nil { - //return fmt.Errorf("r.checkForMissingPrevEvents: %w", err) - softfail = true + isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) } } @@ -225,7 +224,7 @@ func (r *Inputer) processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil && input.Kind != api.KindOld { return fmt.Errorf("r.calculateAndSetState: %w", err) } @@ -233,7 +232,7 @@ func (r *Inputer) processRoomEvent( // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logger.WithField("soft_fail", softfail).WithField("reason", rejectionErr).Debug("Stored rejected event") + logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event") return rejectionErr } @@ -395,7 +394,7 @@ func (r *Inputer) checkForMissingAuthEvents( func (r *Inputer) calculateAndSetState( ctx context.Context, input *api.InputRoomEvent, - roomInfo types.RoomInfo, + roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, event *gomatrixserverlib.Event, isRejected bool, diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index c9264a27d..6137941e1 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -199,7 +199,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) + roomState := state.NewStateResolution(u.api.DB, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index e198f67d8..f3623de8a 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -77,7 +77,7 @@ func (r *Backfiller) PerformBackfill( } // Scan the event tree for events to send back. - resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -418,7 +418,7 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 98f5f6f96..d19fc8386 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, info) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 3083d6537..6aa2bf26d 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -232,7 +232,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db, *info) + roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b62184e2e..c9978de1d 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -63,7 +63,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, info) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -294,7 +294,7 @@ func (r *Queryer) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -377,7 +377,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, ) return } @@ -416,7 +416,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) } - resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -458,7 +458,7 @@ func (r *Queryer) QueryStateAndAuthChain( response.RoomVersion = info.RoomVersion var stateEvents []*gomatrixserverlib.Event - stateEvents, err = r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) + stateEvents, err = r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) if err != nil { return err } @@ -497,7 +497,7 @@ func (r *Queryer) QueryStateAndAuthChain( return err } -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 78398fc7c..15d592b46 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -32,11 +32,11 @@ import ( type StateResolution struct { db storage.Database - roomInfo types.RoomInfo + roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { +func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo,