Move context

This commit is contained in:
Till Faelligen 2023-11-23 20:34:48 +01:00
parent 0f74cbfb27
commit 7db3e9f689
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E

View file

@ -21,16 +21,15 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// updateLatestEvents updates the list of latest events for this room in the database and writes the // updateLatestEvents updates the list of latest events for this room in the database and writes the
@ -71,7 +70,6 @@ func (r *Inputer) updateLatestEvents(
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx,
api: r, api: r,
updater: updater, updater: updater,
stateAtEvent: stateAtEvent, stateAtEvent: stateAtEvent,
@ -80,12 +78,12 @@ func (r *Inputer) updateLatestEvents(
} }
var updates []api.OutputEvent var updates []api.OutputEvent
updates, err = u.doUpdateLatestEvents(roomInfo) updates, err = u.doUpdateLatestEvents(ctx, roomInfo)
if err != nil { if err != nil {
return fmt.Errorf("u.doUpdateLatestEvents: %w", err) return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
} }
update, err := u.makeOutputNewRoomEvent(transactionID, sendAsServer, updater.LastEventIDSent(), historyVisibility) update, err := u.makeOutputNewRoomEvent(ctx, transactionID, sendAsServer, updater.LastEventIDSent(), historyVisibility)
if err != nil { if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
} }
@ -118,7 +116,6 @@ func (r *Inputer) updateLatestEvents(
// The state could be passed using function arguments, but it becomes impractical // The state could be passed using function arguments, but it becomes impractical
// when there are so many variables to pass around. // when there are so many variables to pass around.
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context
api *Inputer api *Inputer
updater *shared.RoomUpdater updater *shared.RoomUpdater
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
@ -140,7 +137,7 @@ type latestEventsUpdater struct {
newStateNID types.StateSnapshotNID newStateNID types.StateSnapshotNID
} }
func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([]api.OutputEvent, error) { func (u *latestEventsUpdater) doUpdateLatestEvents(ctx context.Context, roomInfo *types.RoomInfo) ([]api.OutputEvent, error) {
// If we are doing a regular event update then we will get the // If we are doing a regular event update then we will get the
// previous latest events to use as a part of the calculation. If // previous latest events to use as a part of the calculation. If
// we are overwriting the latest events because we have a complete // we are overwriting the latest events because we have a complete
@ -164,6 +161,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([]
// Work out what the latest events are. This will include the new // Work out what the latest events are. This will include the new
// event if it is not already referenced. // event if it is not already referenced.
extremitiesChanged, err := u.calculateLatest( extremitiesChanged, err := u.calculateLatest(
ctx,
u.oldLatest, u.event, u.oldLatest, u.event,
types.StateAtEventAndReference{ types.StateAtEventAndReference{
EventID: u.event.EventID(), EventID: u.event.EventID(),
@ -178,13 +176,13 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([]
// latest state. // latest state.
var membershipUpdates []api.OutputEvent var membershipUpdates []api.OutputEvent
if extremitiesChanged || u.rewritesState { if extremitiesChanged || u.rewritesState {
if err = u.latestState(roomInfo); err != nil { if err = u.latestState(ctx, roomInfo); err != nil {
return nil, fmt.Errorf("u.latestState: %w", err) return nil, fmt.Errorf("u.latestState: %w", err)
} }
// If we need to generate any output events then here's where we do it. // If we need to generate any output events then here's where we do it.
// TODO: Move this! // TODO: Move this!
if membershipUpdates, err = u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added); err != nil { if membershipUpdates, err = u.api.updateMemberships(ctx, u.updater, u.removed, u.added); err != nil {
return nil, fmt.Errorf("u.api.updateMemberships: %w", err) return nil, fmt.Errorf("u.api.updateMemberships: %w", err)
} }
} else { } else {
@ -198,8 +196,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([]
return membershipUpdates, nil return membershipUpdates, nil
} }
func (u *latestEventsUpdater) latestState(roomInfo *types.RoomInfo) error { func (u *latestEventsUpdater) latestState(ctx context.Context, roomInfo *types.RoomInfo) error {
trace, ctx := internal.StartRegion(u.ctx, "processEventWithMissingState") trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState")
defer trace.EndRegion() defer trace.EndRegion()
var err error var err error
@ -315,11 +313,12 @@ func (u *latestEventsUpdater) latestState(roomInfo *types.RoomInfo) error {
// calculateLatest works out the new set of forward extremities. Returns // calculateLatest works out the new set of forward extremities. Returns
// true if the new event is included in those extremites, false otherwise. // true if the new event is included in those extremites, false otherwise.
func (u *latestEventsUpdater) calculateLatest( func (u *latestEventsUpdater) calculateLatest(
ctx context.Context,
oldLatest []types.StateAtEventAndReference, oldLatest []types.StateAtEventAndReference,
newEvent gomatrixserverlib.PDU, newEvent gomatrixserverlib.PDU,
newStateAndRef types.StateAtEventAndReference, newStateAndRef types.StateAtEventAndReference,
) (bool, error) { ) (bool, error) {
trace, _ := internal.StartRegion(u.ctx, "calculateLatest") trace, _ := internal.StartRegion(ctx, "calculateLatest")
defer trace.EndRegion() defer trace.EndRegion()
// First of all, get a list of all of the events in our current // First of all, get a list of all of the events in our current
@ -377,6 +376,7 @@ func (u *latestEventsUpdater) calculateLatest(
} }
func (u *latestEventsUpdater) makeOutputNewRoomEvent( func (u *latestEventsUpdater) makeOutputNewRoomEvent(
ctx context.Context,
transactionID *api.TransactionID, transactionID *api.TransactionID,
sendAsServer string, sendAsServer string,
lastEventIDSent string, lastEventIDSent string,
@ -397,7 +397,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent(
HistoryVisibility: historyVisibility, HistoryVisibility: historyVisibility,
} }
eventIDMap, err := u.stateEventMap() eventIDMap, err := u.stateEventMap(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -421,7 +421,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent(
} }
// retrieve an event nid -> event ID map for all events that need updating // retrieve an event nid -> event ID map for all events that need updating
func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { func (u *latestEventsUpdater) stateEventMap(ctx context.Context) (map[types.EventNID]string, error) {
cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds) cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds)
stateEventNIDs := make(types.EventNIDs, 0, cap) stateEventNIDs := make(types.EventNIDs, 0, cap)
allStateEntries := make([]types.StateEntry, 0, cap) allStateEntries := make([]types.StateEntry, 0, cap)
@ -433,5 +433,5 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)] stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)]
return u.updater.EventIDs(u.ctx, stateEventNIDs) return u.updater.EventIDs(ctx, stateEventNIDs)
} }