diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 9af0bf591..0229f822f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,22 +20,17 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -type checkForAuthAndSoftFailStorage interface { - state.StateResolutionStorage - StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) -} - // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -97,7 +92,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db checkForAuthAndSoftFailStorage, + db storage.Database, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 22e4b67a0..178533ded 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,7 +19,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "sync" "time" @@ -40,19 +39,6 @@ import ( "github.com/tidwall/gjson" ) -type retryAction int -type commitAction int - -const ( - doNotRetry retryAction = iota - retryLater -) - -const ( - commitTransaction commitAction = iota - rollbackTransaction -) - var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -117,8 +103,7 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) - if err != nil { + if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -127,11 +112,8 @@ func (r *Inputer) Start() error { "event_id": inputRoomEvent.Event.EventID(), "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") - } - switch action { - case retryLater: - _ = msg.Nak() - case doNotRetry: + _ = msg.Term() + } else { _ = msg.Ack() } }) @@ -153,37 +135,6 @@ func (r *Inputer) Start() error { return err } -// processRoomEventUsingUpdater opens up a room updater and tries to -// process the event. It returns whether or not we should positively -// or negatively acknowledge the event (i.e. for NATS) and an error -// if it occurred. -func (r *Inputer) processRoomEventUsingUpdater( - ctx context.Context, - roomID string, - inputRoomEvent *api.InputRoomEvent, -) (retryAction, error) { - roomInfo, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) - } - updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) - } - action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) - switch action { - case commitTransaction: - if cerr := updater.Commit(); cerr != nil { - return retryLater, fmt.Errorf("updater.Commit: %w", cerr) - } - case rollbackTransaction: - if rerr := updater.Rollback(); rerr != nil { - return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) - } - } - return doNotRetry, err -} - // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -230,7 +181,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) + err := r.processRoomEvent(ctx, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4e151699e..531d6959e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -26,10 +26,10 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -68,15 +68,14 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (commitAction, error) { +) error { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return rollbackTransaction, context.DeadlineExceeded + return context.DeadlineExceeded default: } @@ -109,7 +108,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -118,11 +117,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } default: logger.Debugf("Already processed event; ignoring") - return rollbackTransaction, nil + return nil } } } @@ -131,17 +130,21 @@ func (r *Inputer) processRoomEvent( // Don't waste time processing the event if the room doesn't exist. // A room entry locally will only be created in response to a create // event. + roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) + if rerr != nil { + return fmt.Errorf("r.DB.RoomInfo: %w", rerr) + } isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") - if !updater.RoomExists() && !isCreateEvent { - return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + if roomInfo == nil && !isCreateEvent { + return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } var missingAuth, missingPrev bool serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} if !isCreateEvent { - missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) } missingAuth = len(missingAuthIDs) > 0 missingPrev = !input.HasState && len(missingPrevIDs) > 0 @@ -153,7 +156,7 @@ func (r *Inputer) processRoomEvent( ExcludeSelf: true, } if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the @@ -182,8 +185,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -205,12 +208,12 @@ func (r *Inputer) processRoomEvent( // but weren't found. if isRejected { if event.StateKey() != nil { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for state event %s (type %q, state key %q)", authEventID, event.EventID(), event.Type(), *event.StateKey(), ) } else { - return commitTransaction, fmt.Errorf( + return fmt.Errorf( "missing auth event %s for timeline event %s (type %q)", authEventID, event.EventID(), event.Type(), ) @@ -226,7 +229,7 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks based on the // current room state. var err error - softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -250,7 +253,8 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - db: updater, + db: r.DB, + roomInfo: roomInfo, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), @@ -290,16 +294,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -310,23 +314,25 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } - roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) + // Request the room info again — it's possible that the room has been + // created by now if it didn't exist already. + roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) if err != nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) + return fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) + return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) + return fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -337,16 +343,15 @@ func (r *Inputer) processRoomEvent( "missing_prev": missingPrev, }).Warn("Stored rejected event") if rejectionErr != nil { - return commitTransaction, types.RejectedError(rejectionErr.Error()) + return types.RejectedError(rejectionErr.Error()) } - return commitTransaction, nil + return nil } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context - updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -354,7 +359,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) + return fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -366,7 +371,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -385,14 +390,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return commitTransaction, nil + return nil } // fetchAuthEvents will check to see if any of the @@ -404,7 +409,6 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, - updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -418,7 +422,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -495,7 +499,7 @@ nextAuthEvent: } // Finally, store the event in the database. - eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -520,14 +524,18 @@ nextAuthEvent: func (r *Inputer) calculateAndSetState( ctx context.Context, - updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, event *gomatrixserverlib.Event, isRejected bool, ) error { - var err error + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { @@ -536,7 +544,7 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) @@ -557,5 +565,6 @@ func (r *Inputer) calculateAndSetState( if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } + succeeded = true return nil } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index ae28ebefa..f4a52031a 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -47,7 +48,6 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, - updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,6 +55,14 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { + var succeeded bool + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + u := latestEventsUpdater{ ctx: ctx, api: r, @@ -71,6 +79,7 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } + succeeded = true return } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index fc3be7987..4655e92a9 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -25,7 +25,8 @@ type parsedRespState struct { type missingStateReq struct { origin gomatrixserverlib.ServerName - db *shared.RoomUpdater + db storage.Database + roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI @@ -80,7 +81,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -139,8 +140,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) - if err != nil { + if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -163,7 +163,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -182,7 +182,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -473,8 +473,10 @@ retryAllowedState: // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - - latest := t.db.LatestEvents() + latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) + if err != nil { + return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) + } latestEvents := make([]string, len(latest)) for i, ev := range latest { latestEvents[i] = ev.EventID diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a9851e05b..685505d52 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -35,6 +35,11 @@ type Database interface { stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (types.StateSnapshotNID, error) + + MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, + ) (missingAuth, missingPrev []string, err error) + // Look up the state of a room at each event for a list of string event IDs. // Returns an error if there is an error talking to the database. // The length of []types.StateAtEvent is guaranteed to equal the length of eventIDs if no error is returned. diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 810a18ef2..d4a2ee3b9 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -103,25 +103,6 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } -func (u *RoomUpdater) MissingAuthPrevEvents( - ctx context.Context, e *gomatrixserverlib.Event, -) (missingAuth, missingPrev []string, err error) { - for _, authEventID := range e.AuthEventIDs() { - if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - missingAuth = append(missingAuth, authEventID) - } - } - - for _, prevEventID := range e.PrevEventIDs() { - state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { - missingPrev = append(missingPrev, prevEventID) - } - } - - return -} - // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { @@ -146,13 +127,6 @@ func (u *RoomUpdater) SnapshotNIDFromEventID( return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) } -func (u *RoomUpdater) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, - authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { - return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) -} - func (u *RoomUpdater) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { @@ -212,44 +186,16 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } -func (u *RoomUpdater) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) -} - -func (u *RoomUpdater) UnsentEventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) -} - func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) -} - -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) -} - func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } -func (u *RoomUpdater) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) -} - // IsReferenced implements types.RoomRecentEventsUpdater func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e270e121c..6e84b2832 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -674,6 +674,29 @@ func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } +func (d *Database) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs()) + if err != nil { + return nil, nil, fmt.Errorf("d.EventNIDs: %w", err) + } + for _, authEventID := range e.AuthEventIDs() { + if _, ok := authEventNIDs[authEventID]; !ok { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := d.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion,