diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index dabf4ee47..287db1af2 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -64,7 +64,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( } // Store the event. - roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } @@ -89,15 +89,6 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } - if stateAtEvent.BeforeStateSnapshotNID == 0 { - // We haven't calculated a state for this event yet. - // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) - if err != nil { - return "", fmt.Errorf("r.calculateAndSetState: %w", err) - } - } - roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) if err != nil { return "", fmt.Errorf("r.DB.RoomInfo: %w", err) @@ -106,6 +97,15 @@ func (r *RoomserverInternalAPI) processRoomEvent( return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) } + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) + if err != nil { + return "", fmt.Errorf("r.calculateAndSetState: %w", err) + } + } + if err = r.updateLatestEvents( ctx, // context roomInfo, // room info for the room being updated @@ -143,19 +143,19 @@ func (r *RoomserverInternalAPI) processRoomEvent( func (r *RoomserverInternalAPI) calculateAndSetState( ctx context.Context, input api.InputRoomEvent, - roomNID types.RoomNID, + roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -169,14 +169,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState( return err } - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { return err } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { return err } } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 76dda3999..d5e38e7a4 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -55,7 +55,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomInfo.RoomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } @@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB) + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error { // of the state after the events. The snapshot state will be resolved // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( - u.ctx, u.roomInfo.RoomNID, latestStateAtEvents, + u.ctx, latestStateAtEvents, ) if err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform_backfill.go index 65c88860c..03644a7c8 100644 --- a/roomserver/internal/perform_backfill.go +++ b/roomserver/internal/perform_backfill.go @@ -189,7 +189,17 @@ FindSuccessor: return nil } - stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) + info, err := b.db.RoomInfo(ctx, roomID) + if err != nil { + logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room") + return nil + } + if info == nil || info.IsStub { + logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing") + return nil + } + + stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform_invite.go index 1cfbcc18c..6690de055 100644 --- a/roomserver/internal/perform_invite.go +++ b/roomserver/internal/perform_invite.go @@ -208,7 +208,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(db, *info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 266e7100d..14f79773a 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -38,27 +38,22 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - roomVersion, err := r.roomVersion(request.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { + return err + } + if roomInfo == nil || roomInfo.IsStub { response.RoomExists = false return nil } - roomState := state.NewStateResolution(r.DB) - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info.IsStub { - return nil - } + roomState := state.NewStateResolution(r.DB, *roomInfo) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = roomInfo.RoomVersion var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = - r.DB.LatestEventIDs(ctx, info.RoomNID) + r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) if err != nil { return err } @@ -85,7 +80,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) } return nil @@ -97,23 +92,17 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - roomVersion, err := r.roomVersion(request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } + + roomState := state.NewStateResolution(r.DB, *info) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = info.RoomVersion prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { @@ -128,7 +117,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( // Look up the currrent state for the requested tuples. stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, info.RoomNID, prevStates, request.StateToFetch, + ctx, prevStates, request.StateToFetch, ) if err != nil { return err @@ -140,7 +129,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) } return nil @@ -277,7 +266,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) + stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -297,8 +286,8 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( return nil } -func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db) +func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -370,20 +359,28 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see return } - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID()) + roomID := events[0].RoomID() + isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID) if err != nil { return } + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return err + } + if info == nil { + return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) + } response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, request.EventID, request.ServerName, isServerInRoom, + ctx, *info, request.EventID, request.ServerName, isServerInRoom, ) return } func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err @@ -418,8 +415,22 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( eventsToFilter[id] = true } } + events, err := r.DB.EventsFromIDs(ctx, front) + if err != nil { + return err + } + if len(events) == 0 { + return nil // we are missing the events being asked to search from, give up. + } + info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + } - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -467,8 +478,16 @@ func (r *RoomserverInternalAPI) PerformBackfill( // this will include these events which is what we want front = request.PrevEventIDs() + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) + } + // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -481,12 +500,7 @@ func (r *RoomserverInternalAPI) PerformBackfill( } for _, event := range loadedEvents { - roomVersion, verr := r.roomVersion(event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(info.RoomVersion)) } return err @@ -642,7 +656,7 @@ func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, // TODO: Remove this when we have tests to assert correctness of this function // nolint:gocyclo func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, front []string, visited map[string]bool, limit int, + ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID @@ -708,7 +722,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom) + allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -744,13 +758,13 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } response.RoomExists = true response.RoomVersion = info.RoomVersion - stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) + stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) if err != nil { return err } @@ -788,8 +802,8 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( return err } -func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(r.DB) +func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { + roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -941,6 +955,9 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( if err != nil { return err } + if info == nil { + return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) + } response.RoomVersion = info.RoomVersion r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b9ad4a504..7b545bea6 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -31,12 +31,14 @@ import ( ) type StateResolution struct { - db storage.Database + db storage.Database + roomInfo types.RoomInfo } -func NewStateResolution(db storage.Database) StateResolution { +func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { return StateResolution{ - db: db, + db: db, + roomInfo: roomInfo, } } @@ -339,7 +341,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples( // This is typically the state before an event. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func (v StateResolution) LoadStateAfterEventsForStringTuples( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { @@ -347,7 +349,7 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples( if err != nil { return nil, err } - return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) + return v.loadStateAfterEventsForNumericTuples(ctx, v.roomInfo.RoomNID, prevStates, numericTuples) } func (v StateResolution) loadStateAfterEventsForNumericTuples( @@ -355,16 +357,10 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples( prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return nil, err - } - if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] - var result []types.StateEntry - result, err = v.loadStateAtSnapshotForNumericTuples( + result, err := v.loadStateAtSnapshotForNumericTuples( ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, ) if err != nil { @@ -403,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples( // TODO: Add metrics for this as it could take a long time for big rooms // with large conflicts. - fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) if err != nil { return nil, err } @@ -527,7 +523,6 @@ func init() { func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, event gomatrixserverlib.Event, - roomNID types.RoomNID, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -542,14 +537,13 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent( } // The state before this event will be the state after the events that came before it. - return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) + return v.CalculateAndStoreStateAfterEvents(ctx, prevStates) } // CalculateAndStoreStateAfterEvents finds the room state after the given events. // Stores the resulting state in the database and returns a numeric ID for that snapshot. func (v StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, - roomNID types.RoomNID, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} @@ -558,7 +552,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) } @@ -590,7 +584,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" stateNID, err := v.db.AddState( - ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) @@ -601,7 +595,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // So fall through to calculateAndStoreStateAfterManyEvents } - stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics) if err != nil { return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) } @@ -624,13 +618,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return metrics.stop(0, err) - } - state, algorithm, conflictLength, err := - v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm if err != nil { return metrics.stop(0, err) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 2dfaebd10..ef7a9f090 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -66,8 +66,6 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) - // Look up a room version from the room NID. - GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -91,7 +89,7 @@ type Database interface { // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error) + GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) // Look up event ID by transaction's info. // This is used to determine if the room event is processed/processing already. // Returns an empty string if no such event exists. diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index e9a0f6982..29eab0c98 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -12,15 +12,15 @@ import ( type LatestEventsUpdater struct { transaction d *Database - roomNID types.RoomNID + roomInfo types.RoomInfo latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID } -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) { +func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -39,14 +39,13 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomN } } return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } // RoomVersion implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return + return u.roomInfo.RoomVersion } // LatestEvents implements types.RoomRecentEventsUpdater @@ -118,5 +117,5 @@ func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cba1dcf00..6e0ebd2c2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -229,19 +229,6 @@ func (d *Database) StateEntries( return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - return roomVersion, nil - } - } - return d.RoomsTable.SelectRoomVersionForRoomNID( - ctx, roomNID, - ) -} - func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) @@ -376,7 +363,7 @@ func (d *Database) MembershipUpdater( } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*LatestEventsUpdater, error) { txn, err := d.DB.Begin() if err != nil { @@ -384,7 +371,7 @@ func (d *Database) GetLatestEventsForUpdate( } var updater *LatestEventsUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) + updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) return nil }) return updater, err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 87dce6ad1..1f135fd22 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -150,7 +150,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*shared.LatestEventsUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional @@ -158,7 +158,7 @@ func (d *Database) GetLatestEventsForUpdate( // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID) + return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater(