Factor out remaining bits for RoomInfo

This commit is contained in:
Kegan Dougal 2020-09-01 17:06:50 +01:00
parent e66f9f5c30
commit 701c06d4dc
10 changed files with 118 additions and 118 deletions

View file

@ -64,7 +64,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
} }
// Store the event. // Store the event.
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil { if err != nil {
return "", fmt.Errorf("r.DB.StoreEvent: %w", err) return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
} }
@ -89,15 +89,6 @@ func (r *RoomserverInternalAPI) processRoomEvent(
return event.EventID(), nil return event.EventID(), nil
} }
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
if err != nil {
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
return "", fmt.Errorf("r.DB.RoomInfo: %w", err) return "", fmt.Errorf("r.DB.RoomInfo: %w", err)
@ -106,6 +97,15 @@ func (r *RoomserverInternalAPI) processRoomEvent(
return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
} }
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event)
if err != nil {
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
roomInfo, // room info for the room being updated roomInfo, // room info for the room being updated
@ -143,19 +143,19 @@ func (r *RoomserverInternalAPI) processRoomEvent(
func (r *RoomserverInternalAPI) calculateAndSetState( func (r *RoomserverInternalAPI) calculateAndSetState(
ctx context.Context, ctx context.Context,
input api.InputRoomEvent, input api.InputRoomEvent,
roomNID types.RoomNID, roomInfo types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
) error { ) error {
var err error var err error
roomState := state.NewStateResolution(r.DB) roomState := state.NewStateResolution(r.DB, roomInfo)
if input.HasState { if input.HasState {
// Check here if we think we're in the room already. // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
// Request join memberships only for local users only. // Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about // If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case // the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it. // we should overwrite it rather than merge it.
@ -169,14 +169,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState(
return err return err
} }
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return err return err
} }
} else { } else {
stateAtEvent.Overwrite = false stateAtEvent.Overwrite = false
// We haven't been told what the state at the event is so we need to calculate it from the prev_events // We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil {
return err return err
} }
} }

View file

@ -55,7 +55,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
sendAsServer string, sendAsServer string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomInfo.RoomNID) updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
} }
@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
var err error var err error
roomState := state.NewStateResolution(u.api.DB) roomState := state.NewStateResolution(u.api.DB, *u.roomInfo)
// Get a list of the current latest events. // Get a list of the current latest events.
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error {
// of the state after the events. The snapshot state will be resolved // of the state after the events. The snapshot state will be resolved
// using the correct state resolution algorithm for the room. // using the correct state resolution algorithm for the room.
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
u.ctx, u.roomInfo.RoomNID, latestStateAtEvents, u.ctx, latestStateAtEvents,
) )
if err != nil { if err != nil {
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)

View file

@ -189,7 +189,17 @@ FindSuccessor:
return nil return nil
} }
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) info, err := b.db.RoomInfo(ctx, roomID)
if err != nil {
logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room")
return nil
}
if info == nil || info.IsStub {
logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing")
return nil
}
stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID])
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil return nil

View file

@ -208,7 +208,7 @@ func buildInviteStrippedState(
StateKey: "", StateKey: "",
}) })
} }
roomState := state.NewStateResolution(db) roomState := state.NewStateResolution(db, *info)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, info.StateSnapshotNID, stateWanted, ctx, info.StateSnapshotNID, stateWanted,
) )

View file

@ -38,27 +38,22 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest, request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) error { ) error {
roomVersion, err := r.roomVersion(request.RoomID) roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return err
}
if roomInfo == nil || roomInfo.IsStub {
response.RoomExists = false response.RoomExists = false
return nil return nil
} }
roomState := state.NewStateResolution(r.DB) roomState := state.NewStateResolution(r.DB, *roomInfo)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info.IsStub {
return nil
}
response.RoomExists = true response.RoomExists = true
response.RoomVersion = roomVersion response.RoomVersion = roomInfo.RoomVersion
var currentStateSnapshotNID types.StateSnapshotNID var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, info.RoomNID) r.DB.LatestEventIDs(ctx, roomInfo.RoomNID)
if err != nil { if err != nil {
return err return err
} }
@ -85,7 +80,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion))
} }
return nil return nil
@ -97,23 +92,17 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse, response *api.QueryStateAfterEventsResponse,
) error { ) error {
roomVersion, err := r.roomVersion(request.RoomID)
if err != nil {
response.RoomExists = false
return nil
}
roomState := state.NewStateResolution(r.DB)
info, err := r.DB.RoomInfo(ctx, request.RoomID) info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return err return err
} }
if info.IsStub { if info == nil || info.IsStub {
return nil return nil
} }
roomState := state.NewStateResolution(r.DB, *info)
response.RoomExists = true response.RoomExists = true
response.RoomVersion = roomVersion response.RoomVersion = info.RoomVersion
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil { if err != nil {
@ -128,7 +117,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
// Look up the currrent state for the requested tuples. // Look up the currrent state for the requested tuples.
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
ctx, info.RoomNID, prevStates, request.StateToFetch, ctx, prevStates, request.StateToFetch,
) )
if err != nil { if err != nil {
return err return err
@ -140,7 +129,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
} }
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
} }
return nil return nil
@ -277,7 +266,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs) events, err = r.DB.Events(ctx, eventNIDs)
} else { } else {
stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
@ -297,8 +286,8 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
return nil return nil
} }
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db) roomState := state.NewStateResolution(db, info)
// Lookup the event NID // Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil { if err != nil {
@ -370,20 +359,28 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
return return
} }
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID()) roomID := events[0].RoomID()
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID)
if err != nil { if err != nil {
return return
} }
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return err
}
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
}
response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
ctx, request.EventID, request.ServerName, isServerInRoom, ctx, *info, request.EventID, request.ServerName, isServerInRoom,
) )
return return
} }
func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) { ) (bool, error) {
roomState := state.NewStateResolution(r.DB) roomState := state.NewStateResolution(r.DB, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil { if err != nil {
return false, err return false, err
@ -418,8 +415,22 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
eventsToFilter[id] = true eventsToFilter[id] = true
} }
} }
events, err := r.DB.EventsFromIDs(ctx, front)
if err != nil {
return err
}
if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up.
}
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
if err != nil {
return err
}
if info == nil || info.IsStub {
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -467,8 +478,16 @@ func (r *RoomserverInternalAPI) PerformBackfill(
// this will include these events which is what we want // this will include these events which is what we want
front = request.PrevEventIDs() front = request.PrevEventIDs()
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info == nil || info.IsStub {
return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
}
// Scan the event tree for events to send back. // Scan the event tree for events to send back.
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -481,12 +500,7 @@ func (r *RoomserverInternalAPI) PerformBackfill(
} }
for _, event := range loadedEvents { for _, event := range loadedEvents {
roomVersion, verr := r.roomVersion(event.RoomID()) response.Events = append(response.Events, event.Headered(info.RoomVersion))
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
} }
return err return err
@ -642,7 +656,7 @@ func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context,
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo // nolint:gocyclo
func (r *RoomserverInternalAPI) scanEventTree( func (r *RoomserverInternalAPI) scanEventTree(
ctx context.Context, front []string, visited map[string]bool, limit int, ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
var resultNIDs []types.EventNID var resultNIDs []types.EventNID
@ -708,7 +722,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom) allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",
@ -744,13 +758,13 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
if err != nil { if err != nil {
return err return err
} }
if info.IsStub { if info == nil || info.IsStub {
return nil return nil
} }
response.RoomExists = true response.RoomExists = true
response.RoomVersion = info.RoomVersion response.RoomVersion = info.RoomVersion
stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -788,8 +802,8 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
return err return err
} }
func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(r.DB) roomState := state.NewStateResolution(r.DB, roomInfo)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
@ -941,6 +955,9 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom(
if err != nil { if err != nil {
return err return err
} }
if info == nil {
return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID)
}
response.RoomVersion = info.RoomVersion response.RoomVersion = info.RoomVersion
r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
return nil return nil

View file

@ -31,12 +31,14 @@ import (
) )
type StateResolution struct { type StateResolution struct {
db storage.Database db storage.Database
roomInfo types.RoomInfo
} }
func NewStateResolution(db storage.Database) StateResolution { func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution {
return StateResolution{ return StateResolution{
db: db, db: db,
roomInfo: roomInfo,
} }
} }
@ -339,7 +341,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples(
// This is typically the state before an event. // This is typically the state before an event.
// Returns a sorted list of state entries or an error if there was a problem talking to the database. // Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) LoadStateAfterEventsForStringTuples( func (v StateResolution) LoadStateAfterEventsForStringTuples(
ctx context.Context, roomNID types.RoomNID, ctx context.Context,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
@ -347,7 +349,7 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples(
if err != nil { if err != nil {
return nil, err return nil, err
} }
return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) return v.loadStateAfterEventsForNumericTuples(ctx, v.roomInfo.RoomNID, prevStates, numericTuples)
} }
func (v StateResolution) loadStateAfterEventsForNumericTuples( func (v StateResolution) loadStateAfterEventsForNumericTuples(
@ -355,16 +357,10 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
}
if len(prevStates) == 1 { if len(prevStates) == 1 {
// Fast path for a single event. // Fast path for a single event.
prevState := prevStates[0] prevState := prevStates[0]
var result []types.StateEntry result, err := v.loadStateAtSnapshotForNumericTuples(
result, err = v.loadStateAtSnapshotForNumericTuples(
ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples,
) )
if err != nil { if err != nil {
@ -403,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples(
// TODO: Add metrics for this as it could take a long time for big rooms // TODO: Add metrics for this as it could take a long time for big rooms
// with large conflicts. // with large conflicts.
fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -527,7 +523,6 @@ func init() {
func (v StateResolution) CalculateAndStoreStateBeforeEvent( func (v StateResolution) CalculateAndStoreStateBeforeEvent(
ctx context.Context, ctx context.Context,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
roomNID types.RoomNID,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
// Load the state at the prev events. // Load the state at the prev events.
prevEventRefs := event.PrevEvents() prevEventRefs := event.PrevEvents()
@ -542,14 +537,13 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent(
} }
// The state before this event will be the state after the events that came before it. // The state before this event will be the state after the events that came before it.
return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) return v.CalculateAndStoreStateAfterEvents(ctx, prevStates)
} }
// CalculateAndStoreStateAfterEvents finds the room state after the given events. // CalculateAndStoreStateAfterEvents finds the room state after the given events.
// Stores the resulting state in the database and returns a numeric ID for that snapshot. // Stores the resulting state in the database and returns a numeric ID for that snapshot.
func (v StateResolution) CalculateAndStoreStateAfterEvents( func (v StateResolution) CalculateAndStoreStateAfterEvents(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID,
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
@ -558,7 +552,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
// 2) There weren't any prev_events for this event so the state is // 2) There weren't any prev_events for this event so the state is
// empty. // empty.
metrics.algorithm = "empty_state" metrics.algorithm = "empty_state"
stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil)
if err != nil { if err != nil {
err = fmt.Errorf("v.db.AddState: %w", err) err = fmt.Errorf("v.db.AddState: %w", err)
} }
@ -590,7 +584,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
// add the state event as a block of size one to the end of the blocks. // add the state event as a block of size one to the end of the blocks.
metrics.algorithm = "single_delta" metrics.algorithm = "single_delta"
stateNID, err := v.db.AddState( stateNID, err := v.db.AddState(
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
) )
if err != nil { if err != nil {
err = fmt.Errorf("v.db.AddState: %w", err) err = fmt.Errorf("v.db.AddState: %w", err)
@ -601,7 +595,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
// So fall through to calculateAndStoreStateAfterManyEvents // So fall through to calculateAndStoreStateAfterManyEvents
} }
stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics)
if err != nil { if err != nil {
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
} }
@ -624,13 +618,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents(
prevStates []types.StateAtEvent, prevStates []types.StateAtEvent,
metrics calculateStateMetrics, metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return metrics.stop(0, err)
}
state, algorithm, conflictLength, err := state, algorithm, conflictLength, err :=
v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
metrics.algorithm = algorithm metrics.algorithm = algorithm
if err != nil { if err != nil {
return metrics.stop(0, err) return metrics.stop(0, err)

View file

@ -66,8 +66,6 @@ type Database interface {
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
// Look up a room version from the room NID.
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent( StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
@ -91,7 +89,7 @@ type Database interface {
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater. // Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required. // If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error) GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
// Look up event ID by transaction's info. // Look up event ID by transaction's info.
// This is used to determine if the room event is processed/processing already. // This is used to determine if the room event is processed/processing already.
// Returns an empty string if no such event exists. // Returns an empty string if no such event exists.

View file

@ -12,15 +12,15 @@ import (
type LatestEventsUpdater struct { type LatestEventsUpdater struct {
transaction transaction
d *Database d *Database
roomNID types.RoomNID roomInfo types.RoomInfo
latestEvents []types.StateAtEventAndReference latestEvents []types.StateAtEventAndReference
lastEventIDSent string lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID currentStateSnapshotNID types.StateSnapshotNID
} }
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) { func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil { if err != nil {
txn.Rollback() // nolint: errcheck txn.Rollback() // nolint: errcheck
return nil, err return nil, err
@ -39,14 +39,13 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomN
} }
} }
return &LatestEventsUpdater{ return &LatestEventsUpdater{
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil }, nil
} }
// RoomVersion implements types.RoomRecentEventsUpdater // RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) return u.roomInfo.RoomVersion
return
} }
// LatestEvents implements types.RoomRecentEventsUpdater // LatestEvents implements types.RoomRecentEventsUpdater
@ -118,5 +117,5 @@ func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
} }
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }

View file

@ -229,19 +229,6 @@ func (d *Database) StateEntries(
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
} }
func (d *Database) GetRoomVersionForRoomNID(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
return roomVersion, nil
}
}
return d.RoomsTable.SelectRoomVersionForRoomNID(
ctx, roomNID,
)
}
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID)
@ -376,7 +363,7 @@ func (d *Database) MembershipUpdater(
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) { ) (*LatestEventsUpdater, error) {
txn, err := d.DB.Begin() txn, err := d.DB.Begin()
if err != nil { if err != nil {
@ -384,7 +371,7 @@ func (d *Database) GetLatestEventsForUpdate(
} }
var updater *LatestEventsUpdater var updater *LatestEventsUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
return nil return nil
}) })
return updater, err return updater, err

View file

@ -150,7 +150,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomInfo types.RoomInfo,
) (*shared.LatestEventsUpdater, error) { ) (*shared.LatestEventsUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have // TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional // multiple write transactions on sqlite. The code will perform additional
@ -158,7 +158,7 @@ func (d *Database) GetLatestEventsForUpdate(
// 'database is locked' errors. As sqlite doesn't support multi-process on the // 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries // same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity) // are for rolling back when things go wrong. (atomicity)
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID) return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
} }
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(