From 92ed93db38c908166bb5036a1a3cb64bdc6b7530 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 19 Mar 2020 15:18:29 +0000 Subject: [PATCH] Try to use room version to get correct state resolution algorithm --- roomserver/input/events.go | 7 ++-- roomserver/input/latest_events.go | 6 +--- roomserver/query/query.go | 32 ++++------------- roomserver/state/database/database.go | 3 ++ roomserver/state/state.go | 50 ++++++++++++--------------- roomserver/storage/interface.go | 1 - 6 files changed, 35 insertions(+), 64 deletions(-) diff --git a/roomserver/input/events.go b/roomserver/input/events.go index a9afcaee0..6502b6432 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -156,11 +156,8 @@ func calculateAndSetState( stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { - // TODO: get the correct room version - roomState, err := state.Prepare(db, gomatrixserverlib.RoomVersionV1) - if err != nil { - return err - } + var err error + roomState := state.Prepare(db) if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go index 761043051..04d2b81c4 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/input/latest_events.go @@ -178,11 +178,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - // TODO: get the correct room version - roomState, err := state.Prepare(u.db, gomatrixserverlib.RoomVersionV1) - if err != nil { - return err - } + roomState := state.Prepare(u.db) latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) for i := range u.latest { diff --git a/roomserver/query/query.go b/roomserver/query/query.go index fe20ace36..9f3f44198 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -111,10 +111,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( return err } - roomState, err := state.Prepare(r.DB, roomVersion) - if err != nil { - return err - } + roomState := state.Prepare(r.DB) response.QueryLatestEventsAndStateRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) @@ -164,10 +161,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( return err } - roomState, err := state.Prepare(r.DB, roomVersion) - if err != nil { - return err - } + roomState := state.Prepare(r.DB) response.QueryStateAfterEventsRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) @@ -192,7 +186,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( // Look up the currrent state for the requested tuples. stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, prevStates, request.StateToFetch, + ctx, roomNID, prevStates, request.StateToFetch, ) if err != nil { return err @@ -358,11 +352,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( ctx context.Context, eventNID types.EventNID, joinedOnly bool, ) ([]types.Event, error) { - // TODO: get the correct room version - roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1) - if err != nil { - return []types.Event{}, err - } + roomState := state.Prepare(r.DB) events := []types.Event{} // Lookup the event NID eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) @@ -464,12 +454,7 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, ) (bool, error) { - // TODO: get the correct room version - roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1) - if err != nil { - return false, err - } - + roomState := state.Prepare(r.DB) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err @@ -689,12 +674,7 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - // TODO: get the correct room version - roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1) - if err != nil { - return nil, err - } - + roomState := state.Prepare(r.DB) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { diff --git a/roomserver/state/database/database.go b/roomserver/state/database/database.go index ede6c5ec3..80f1b14f4 100644 --- a/roomserver/state/database/database.go +++ b/roomserver/state/database/database.go @@ -20,6 +20,7 @@ import ( "context" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) // A RoomStateDatabase has the storage APIs needed to load state from the database @@ -61,4 +62,6 @@ type RoomStateDatabase interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + // Look up a room version from the room NID. + GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index c5e2d62a5..84e0d459a 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -32,26 +32,12 @@ import ( ) type StateResolution struct { - db database.RoomStateDatabase - version gomatrixserverlib.RoomVersion + db database.RoomStateDatabase } -func Prepare(db database.RoomStateDatabase, version gomatrixserverlib.RoomVersion) (StateResolution, error) { - stateResAlgo, err := version.StateResAlgorithm() - if err != nil { - return StateResolution{}, err - } - - switch stateResAlgo { - case gomatrixserverlib.StateResV1: - fallthrough - case gomatrixserverlib.StateResV2: - return StateResolution{ - db: db, - version: version, - }, nil - default: - return StateResolution{}, errors.New("unsupported state resolution algorithm") +func Prepare(db database.RoomStateDatabase) StateResolution { + return StateResolution{ + db: db, } } @@ -350,7 +336,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples( // This is typically the state before an event. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func (v StateResolution) LoadStateAfterEventsForStringTuples( - ctx context.Context, + ctx context.Context, roomNID types.RoomNID, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { @@ -358,14 +344,19 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples( if err != nil { return nil, err } - return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) + return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) } func (v StateResolution) loadStateAfterEventsForNumericTuples( - ctx context.Context, + ctx context.Context, roomNID types.RoomNID, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { + roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } + if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] @@ -408,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples( // TODO: Add metrics for this as it could take a long time for big rooms // with large conflicts. - fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, prevStates) + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) if err != nil { return nil, err } @@ -617,9 +608,13 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { + roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return metrics.stop(0, err) + } state, algorithm, conflictLength, err := - v.calculateStateAfterManyEvents(ctx, prevStates) + v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) metrics.algorithm = algorithm if err != nil { return metrics.stop(0, err) @@ -633,7 +628,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( } func (v StateResolution) calculateStateAfterManyEvents( - ctx context.Context, prevStates []types.StateAtEvent, + ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, + prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm string, conflictLength int, err error) { var combined []types.StateEntry // Conflict resolution. @@ -668,7 +664,7 @@ func (v StateResolution) calculateStateAfterManyEvents( } var resolved []types.StateEntry - resolved, err = v.resolveConflicts(ctx, notConflicted, conflicts) + resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts) if err != nil { algorithm = "_resolve_conflicts" return @@ -684,10 +680,10 @@ func (v StateResolution) calculateStateAfterManyEvents( } func (v StateResolution) resolveConflicts( - ctx context.Context, + ctx context.Context, version gomatrixserverlib.RoomVersion, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { - stateResAlgo, err := v.version.StateResAlgorithm() + stateResAlgo, err := version.StateResAlgorithm() if err != nil { return nil, err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 98df4708d..20db7ef7f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -46,5 +46,4 @@ type Database interface { GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) - GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) }