diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index b498cc058..5fd45ccca 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -105,24 +105,7 @@ func processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - if input.HasState { - // We've been told what the state at the event is so we don't need to calculate it. - // Check that those state events are in the database and store the state. - var entries []types.StateEntry - if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return err - } - - if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { - return err - } - } else { - // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { - return err - } - } - err = db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) if err != nil { return err } @@ -137,6 +120,35 @@ func processRoomEvent( return updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer, input.TransactionID) } +func calculateAndSetState( + ctx context.Context, + db RoomEventDatabase, + input api.InputRoomEvent, + roomNID types.RoomNID, + stateAtEvent *types.StateAtEvent, + event gomatrixserverlib.Event, +) error { + var err error + if input.HasState { + // We've been told what the state at the event is so we don't need to calculate it. + // Check that those state events are in the database and store the state. + var entries []types.StateEntry + if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return err + } + + if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { + return err + } + } else { + // We haven't been told what the state at the event is so we need to calculate it from the prev_events + if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { + return err + } + } + return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) +} + func processInviteEvent( ctx context.Context, db RoomEventDatabase,