Optimise loadAuthEvents, add roomserver tracing

This commit is contained in:
Neil Alexander 2022-06-07 14:23:26 +01:00
parent aafb7bf120
commit 27948fb304
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
5 changed files with 228 additions and 51 deletions

View file

@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -75,6 +76,11 @@ func (r *Inputer) processRoomEvent(
default: default:
} }
span, ctx := opentracing.StartSpanFromContext(ctx, "processRoomEvent")
span.SetTag("room_id", input.Event.RoomID())
span.SetTag("event_id", input.Event.EventID())
defer span.Finish()
// Measure how long it takes to process this event. // Measure how long it takes to process this event.
started := time.Now() started := time.Now()
defer func() { defer func() {
@ -411,6 +417,9 @@ func (r *Inputer) fetchAuthEvents(
known map[string]*types.Event, known map[string]*types.Event,
servers []gomatrixserverlib.ServerName, servers []gomatrixserverlib.ServerName,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "fetchAuthEvents")
defer span.Finish()
unknown := map[string]struct{}{} unknown := map[string]struct{}{}
authEventIDs := event.AuthEventIDs() authEventIDs := event.AuthEventIDs()
if len(authEventIDs) == 0 { if len(authEventIDs) == 0 {
@ -526,6 +535,9 @@ func (r *Inputer) calculateAndSetState(
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
isRejected bool, isRejected bool,
) error { ) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "calculateAndSetState")
defer span.Finish()
var succeeded bool var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -56,6 +57,9 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool, rewritesState bool,
) (err error) { ) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "updateLatestEvents")
defer span.Finish()
var succeeded bool var succeeded bool
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {
@ -200,6 +204,9 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
} }
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
span, ctx := opentracing.StartSpanFromContext(u.ctx, "processEventWithMissingState")
defer span.Finish()
var err error var err error
roomState := state.NewStateResolution(u.updater, u.roomInfo) roomState := state.NewStateResolution(u.updater, u.roomInfo)
@ -246,7 +253,7 @@ func (u *latestEventsUpdater) latestState() error {
// of the state after the events. The snapshot state will be resolved // of the state after the events. The snapshot state will be resolved
// using the correct state resolution algorithm for the room. // using the correct state resolution algorithm for the room.
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
u.ctx, latestStateAtEvents, ctx, latestStateAtEvents,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
@ -258,7 +265,7 @@ func (u *latestEventsUpdater) latestState() error {
// another list of added ones. Replacing a value for a state-key tuple // another list of added ones. Replacing a value for a state-key tuple
// will result one removed (the old event) and one added (the new event). // will result one removed (the old event) and one added (the new event).
u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.oldStateNID, u.newStateNID, ctx, u.oldStateNID, u.newStateNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
@ -278,7 +285,7 @@ func (u *latestEventsUpdater) latestState() error {
// Also work out the state before the event removes and the event // Also work out the state before the event removes and the event
// adds. // adds.
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
@ -294,6 +301,9 @@ func (u *latestEventsUpdater) calculateLatest(
newEvent *gomatrixserverlib.Event, newEvent *gomatrixserverlib.Event,
newStateAndRef types.StateAtEventAndReference, newStateAndRef types.StateAtEventAndReference,
) (bool, error) { ) (bool, error) {
span, _ := opentracing.StartSpanFromContext(u.ctx, "calculateLatest")
defer span.Finish()
// First of all, get a list of all of the events in our current // First of all, get a list of all of the events in our current
// set of forward extremities. // set of forward extremities.
existingRefs := make(map[string]*types.StateAtEventAndReference) existingRefs := make(map[string]*types.StateAtEventAndReference)

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
) )
// updateMembership updates the current membership and the invites for each // updateMembership updates the current membership and the invites for each
@ -34,6 +35,9 @@ func (r *Inputer) updateMemberships(
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "updateMemberships")
defer span.Finish()
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, change := range changes { for _, change := range changes {

View file

@ -15,6 +15,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -59,6 +60,9 @@ type missingStateReq struct {
func (t *missingStateReq) processEventWithMissingState( func (t *missingStateReq) processEventWithMissingState(
ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
) (*parsedRespState, error) { ) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "processEventWithMissingState")
defer span.Finish()
// We are missing the previous events for this events. // We are missing the previous events for this events.
// This means that there is a gap in our view of the history of the // This means that there is a gap in our view of the history of the
// room. There two ways that we can handle such a gap: // room. There two ways that we can handle such a gap:
@ -235,6 +239,9 @@ func (t *missingStateReq) processEventWithMissingState(
} }
func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupResolvedStateBeforeEvent")
defer span.Finish()
type respState struct { type respState struct {
// A snapshot is considered trustworthy if it came from our own roomserver. // A snapshot is considered trustworthy if it came from our own roomserver.
// That's because the state will have been through state resolution once // That's because the state will have been through state resolution once
@ -310,6 +317,9 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
// added into the mix. // added into the mix.
func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEvent")
defer span.Finish()
// try doing all this locally before we resort to querying federation // try doing all this locally before we resort to querying federation
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
if respState != nil { if respState != nil {
@ -361,6 +371,9 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixs
} }
func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEventLocally")
defer span.Finish()
var res parsedRespState var res parsedRespState
roomInfo, err := t.db.RoomInfo(ctx, roomID) roomInfo, err := t.db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -435,12 +448,17 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
// the server supports. // the server supports.
func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (
*parsedRespState, error) { *parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateBeforeEvent")
defer span.Finish()
// Attempt to fetch the missing state using /state_ids and /events // Attempt to fetch the missing state using /state_ids and /events
return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion)
} }
func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "resolveStatesAndCheck")
defer span.Finish()
var authEventList []*gomatrixserverlib.Event var authEventList []*gomatrixserverlib.Event
var stateEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event
for _, state := range states { for _, state := range states {
@ -484,6 +502,9 @@ retryAllowedState:
// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject,
// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events
func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "getMissingEvents")
defer span.Finish()
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID)
if err != nil { if err != nil {
@ -608,6 +629,9 @@ func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserve
func (t *missingStateReq) lookupMissingStateViaState( func (t *missingStateReq) lookupMissingStateViaState(
ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (respState *parsedRespState, err error) { ) (respState *parsedRespState, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState")
defer span.Finish()
state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
@ -637,6 +661,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
*parsedRespState, error) { *parsedRespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaStateIDs")
defer span.Finish()
util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID)
// fetch the state event IDs at the time of the event // fetch the state event IDs at the time of the event
stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID)
@ -799,6 +826,9 @@ func (t *missingStateReq) createRespStateFromStateIDs(
} }
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "lookupEvent")
defer span.Finish()
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) events, err := t.db.EventsFromIDs(ctx, []string{missingEventID})

View file

@ -20,9 +20,11 @@ import (
"context" "context"
"fmt" "fmt"
"sort" "sort"
"sync"
"time" "time"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -62,6 +64,9 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta
func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID, ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot")
defer span.Finish()
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -100,6 +105,9 @@ func (v *StateResolution) LoadStateAtSnapshot(
func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadStateAtEvent(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent")
defer span.Finish()
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil { if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
@ -122,6 +130,9 @@ func (v *StateResolution) LoadStateAtEvent(
func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) LoadCombinedStateAfterEvents(
ctx context.Context, prevStates []types.StateAtEvent, ctx context.Context, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents")
defer span.Finish()
stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates { for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID stateNIDs[i] = state.BeforeStateSnapshotNID
@ -194,6 +205,9 @@ func (v *StateResolution) LoadCombinedStateAfterEvents(
func (v *StateResolution) DifferenceBetweeenStateSnapshots( func (v *StateResolution) DifferenceBetweeenStateSnapshots(
ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error) { ) (removed, added []types.StateEntry, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots")
defer span.Finish()
if oldStateNID == newStateNID { if oldStateNID == newStateNID {
// If the snapshot NIDs are the same then nothing has changed // If the snapshot NIDs are the same then nothing has changed
return nil, nil, nil return nil, nil, nil
@ -255,6 +269,9 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples(
stateNID types.StateSnapshotNID, stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples")
defer span.Finish()
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
@ -269,6 +286,9 @@ func (v *StateResolution) stringTuplesToNumericTuples(
ctx context.Context, ctx context.Context,
stringTuples []gomatrixserverlib.StateKeyTuple, stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) { ) ([]types.StateKeyTuple, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples")
defer span.Finish()
eventTypes := make([]string, len(stringTuples)) eventTypes := make([]string, len(stringTuples))
stateKeys := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples))
for i := range stringTuples { for i := range stringTuples {
@ -311,6 +331,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
stateNID types.StateSnapshotNID, stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples")
defer span.Finish()
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -359,6 +382,9 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples")
defer span.Finish()
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
@ -371,6 +397,9 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples")
defer span.Finish()
if len(prevStates) == 1 { if len(prevStates) == 1 {
// Fast path for a single event. // Fast path for a single event.
prevState := prevStates[0] prevState := prevStates[0]
@ -543,6 +572,9 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent(
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
isRejected bool, isRejected bool,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent")
defer span.Finish()
// Load the state at the prev events. // Load the state at the prev events.
prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
if err != nil { if err != nil {
@ -559,6 +591,9 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
ctx context.Context, ctx context.Context,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents")
defer span.Finish()
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 { if len(prevStates) == 0 {
@ -631,6 +666,9 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
metrics calculateStateMetrics, metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents")
defer span.Finish()
state, algorithm, conflictLength, err := state, algorithm, conflictLength, err :=
v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
metrics.algorithm = algorithm metrics.algorithm = algorithm
@ -649,6 +687,9 @@ func (v *StateResolution) calculateStateAfterManyEvents(
ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) { ) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents")
defer span.Finish()
var combined []types.StateEntry var combined []types.StateEntry
// Conflict resolution. // Conflict resolution.
// First stage: load the state after each of the prev events. // First stage: load the state after each of the prev events.
@ -701,6 +742,9 @@ func (v *StateResolution) resolveConflicts(
ctx context.Context, version gomatrixserverlib.RoomVersion, ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts")
defer span.Finish()
stateResAlgo, err := version.StateResAlgorithm() stateResAlgo, err := version.StateResAlgorithm()
if err != nil { if err != nil {
return nil, err return nil, err
@ -725,6 +769,8 @@ func (v *StateResolution) resolveConflictsV1(
ctx context.Context, ctx context.Context,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1")
defer span.Finish()
// Load the conflicted events // Load the conflicted events
conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted)
@ -788,6 +834,9 @@ func (v *StateResolution) resolveConflictsV2(
ctx context.Context, ctx context.Context,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2")
defer span.Finish()
estimate := len(conflicted) + len(notConflicted) estimate := len(conflicted) + len(notConflicted)
eventIDMap := make(map[string]types.StateEntry, estimate) eventIDMap := make(map[string]types.StateEntry, estimate)
@ -815,17 +864,29 @@ func (v *StateResolution) resolveConflictsV2(
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3)
authDifference := make([]*gomatrixserverlib.Event, 0, estimate) authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
knownAuthEvents := make(map[string]types.Event, estimate*3)
// For each conflicted event, let's try and get the needed auth events. // For each conflicted event, let's try and get the needed auth events.
if err = func() error {
span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents")
defer span.Finish()
loader := authEventLoader{
v: v,
lookupFromDB: make([]string, 0, len(conflictedEvents)*3),
lookupFromMem: make([]string, 0, len(conflictedEvents)*3),
lookedUpEvents: make([]types.Event, 0, len(conflictedEvents)*3),
eventMap: map[string]types.Event{},
}
for _, conflictedEvent := range conflictedEvents { for _, conflictedEvent := range conflictedEvents {
// Work out which auth events we need to load. // Work out which auth events we need to load.
key := conflictedEvent.EventID() key := conflictedEvent.EventID()
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
var authEventMap map[string]types.StateEntry var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents)
if err != nil { if err != nil {
return nil, err return err
} }
for k, v := range authEventMap { for k, v := range authEventMap {
eventIDMap[k] = v eventIDMap[k] = v
@ -841,6 +902,10 @@ func (v *StateResolution) resolveConflictsV2(
} }
} }
} }
return nil
}(); err != nil {
return nil, err
}
// Kill the reference to this so that the GC may pick it up, since we no // Kill the reference to this so that the GC may pick it up, since we no
// longer need this after this point. // longer need this after this point.
@ -870,19 +935,29 @@ func (v *StateResolution) resolveConflictsV2(
// Look through all of the auth events that we've been given and work out if // Look through all of the auth events that we've been given and work out if
// there are any events which don't appear in all of the auth sets. If they // there are any events which don't appear in all of the auth sets. If they
// don't then we add them to the auth difference. // don't then we add them to the auth difference.
func() {
span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists")
defer span.Finish()
for _, event := range authEvents { for _, event := range authEvents {
if !isInAllAuthLists(event) { if !isInAllAuthLists(event) {
authDifference = append(authDifference, event) authDifference = append(authDifference, event)
} }
} }
}()
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( resolvedEvents := func() []*gomatrixserverlib.Event {
span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2")
defer span.Finish()
return gomatrixserverlib.ResolveStateConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
authDifference, authDifference,
) )
}()
// Map from the full events back to numeric state entries. // Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents { for _, resolvedEvent := range resolvedEvents {
@ -947,6 +1022,9 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E
func (v *StateResolution) loadStateEvents( func (v *StateResolution) loadStateEvents(
ctx context.Context, entries []types.StateEntry, ctx context.Context, entries []types.StateEntry,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents")
defer span.Finish()
result := make([]*gomatrixserverlib.Event, 0, len(entries)) result := make([]*gomatrixserverlib.Event, 0, len(entries))
eventEntries := make([]types.StateEntry, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries))
eventNIDs := make([]types.EventNID, 0, len(entries)) eventNIDs := make([]types.EventNID, 0, len(entries))
@ -975,45 +1053,88 @@ func (v *StateResolution) loadStateEvents(
return result, eventIDMap, nil return result, eventIDMap, nil
} }
type authEventLoader struct {
sync.Mutex
v *StateResolution
lookupFromDB []string // scratch space
lookupFromMem []string // scratch space
lookedUpEvents []types.Event // scratch space
eventMap map[string]types.Event
}
// loadAuthEvents loads all of the auth events for a given event recursively, // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events. // along with a map that contains state entries for all of the auth events.
func (v *StateResolution) loadAuthEvents( func (l *authEventLoader) loadAuthEvents(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventMap := map[string]struct{}{} l.Lock()
var lookup []string defer l.Unlock()
var authEvents []types.Event authEvents := []types.Event{} // our returned list
included := map[string]struct{}{} // dedupes authEvents above
queue := event.AuthEventIDs() queue := event.AuthEventIDs()
for i := 0; i < len(queue); i++ { for i := 0; i < len(queue); i++ {
lookup = lookup[:0] // Reuse the same underlying memory, since it reduces the
// amount of allocations we make the more times we call
// loadAuthEvents.
l.lookupFromDB = l.lookupFromDB[:0]
l.lookupFromMem = l.lookupFromMem[:0]
l.lookedUpEvents = l.lookedUpEvents[:0]
// Separate out the list of events in the queue based on if
// we think we already know the event in memory or not.
for _, authEventID := range queue { for _, authEventID := range queue {
if _, ok := eventMap[authEventID]; ok { if _, ok := included[authEventID]; ok {
continue continue
} }
lookup = append(lookup, authEventID) if _, ok := eventMap[authEventID]; ok {
l.lookupFromMem = append(l.lookupFromMem, authEventID)
} else {
l.lookupFromDB = append(l.lookupFromDB, authEventID)
} }
if len(lookup) == 0 { }
// If there's nothing to do, stop here.
if len(l.lookupFromDB) == 0 && len(l.lookupFromMem) == 0 {
break break
} }
events, err := v.db.EventsFromIDs(ctx, lookup)
// If we need to get events from the database, go and fetch
// those now.
if len(l.lookupFromDB) > 0 {
eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
} }
l.lookedUpEvents = append(l.lookedUpEvents, eventsFromDB...)
for _, event := range eventsFromDB {
eventMap[event.EventID()] = event
}
}
// Fill in the gaps with events that we already have in memory.
if len(l.lookupFromMem) > 0 {
for _, eventID := range l.lookupFromMem {
l.lookedUpEvents = append(l.lookedUpEvents, eventMap[eventID])
}
}
// From the events that we've retrieved, work out which auth
// events to look up on the next iteration.
add := map[string]struct{}{} add := map[string]struct{}{}
for _, event := range events { for _, event := range l.lookedUpEvents {
eventMap[event.EventID()] = struct{}{}
authEvents = append(authEvents, event) authEvents = append(authEvents, event)
included[event.EventID()] = struct{}{}
for _, authEventID := range event.AuthEventIDs() { for _, authEventID := range event.AuthEventIDs() {
if _, ok := eventMap[authEventID]; ok { if _, ok := included[authEventID]; ok {
continue continue
} }
add[authEventID] = struct{}{} add[authEventID] = struct{}{}
} }
}
for authEventID := range add { for authEventID := range add {
queue = append(queue, authEventID) queue = append(queue, authEventID)
} }
} }
}
authEventTypes := map[string]struct{}{} authEventTypes := map[string]struct{}{}
authEventStateKeys := map[string]struct{}{} authEventStateKeys := map[string]struct{}{}
for _, authEvent := range authEvents { for _, authEvent := range authEvents {
@ -1028,11 +1149,11 @@ func (v *StateResolution) loadAuthEvents(
for eventStateKey := range authEventStateKeys { for eventStateKey := range authEventStateKeys {
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
} }
eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) eventTypes, err := l.v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
} }
eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) eventStateKeys, err := l.v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
} }