Try to use room version to get correct state resolution algorithm

This commit is contained in:
Neil Alexander 2020-03-19 15:18:29 +00:00
parent 6654b86b0b
commit 92ed93db38
6 changed files with 35 additions and 64 deletions

View file

@ -156,11 +156,8 @@ func calculateAndSetState(
stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event,
) error {
// TODO: get the correct room version
roomState, err := state.Prepare(db, gomatrixserverlib.RoomVersionV1)
if err != nil {
return err
}
var err error
roomState := state.Prepare(db)
if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it.

View file

@ -178,11 +178,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error {
var err error
// TODO: get the correct room version
roomState, err := state.Prepare(u.db, gomatrixserverlib.RoomVersionV1)
if err != nil {
return err
}
roomState := state.Prepare(u.db)
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
for i := range u.latest {

View file

@ -111,10 +111,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
return err
}
roomState, err := state.Prepare(r.DB, roomVersion)
if err != nil {
return err
}
roomState := state.Prepare(r.DB)
response.QueryLatestEventsAndStateRequest = *request
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
@ -164,10 +161,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
return err
}
roomState, err := state.Prepare(r.DB, roomVersion)
if err != nil {
return err
}
roomState := state.Prepare(r.DB)
response.QueryStateAfterEventsRequest = *request
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
@ -192,7 +186,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
// Look up the currrent state for the requested tuples.
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
ctx, prevStates, request.StateToFetch,
ctx, roomNID, prevStates, request.StateToFetch,
)
if err != nil {
return err
@ -358,11 +352,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
) ([]types.Event, error) {
// TODO: get the correct room version
roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1)
if err != nil {
return []types.Event{}, err
}
roomState := state.Prepare(r.DB)
events := []types.Event{}
// Lookup the event NID
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
@ -464,12 +454,7 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName,
) (bool, error) {
// TODO: get the correct room version
roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1)
if err != nil {
return false, err
}
roomState := state.Prepare(r.DB)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
@ -689,12 +674,7 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
}
func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
// TODO: get the correct room version
roomState, err := state.Prepare(r.DB, gomatrixserverlib.RoomVersionV1)
if err != nil {
return nil, err
}
roomState := state.Prepare(r.DB)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
switch err.(type) {

View file

@ -20,6 +20,7 @@ import (
"context"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// A RoomStateDatabase has the storage APIs needed to load state from the database
@ -61,4 +62,6 @@ type RoomStateDatabase interface {
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
// Look up a room version from the room NID.
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
}

View file

@ -33,25 +33,11 @@ import (
type StateResolution struct {
db database.RoomStateDatabase
version gomatrixserverlib.RoomVersion
}
func Prepare(db database.RoomStateDatabase, version gomatrixserverlib.RoomVersion) (StateResolution, error) {
stateResAlgo, err := version.StateResAlgorithm()
if err != nil {
return StateResolution{}, err
}
switch stateResAlgo {
case gomatrixserverlib.StateResV1:
fallthrough
case gomatrixserverlib.StateResV2:
func Prepare(db database.RoomStateDatabase) StateResolution {
return StateResolution{
db: db,
version: version,
}, nil
default:
return StateResolution{}, errors.New("unsupported state resolution algorithm")
}
}
@ -350,7 +336,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples(
// This is typically the state before an event.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) LoadStateAfterEventsForStringTuples(
ctx context.Context,
ctx context.Context, roomNID types.RoomNID,
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) {
@ -358,14 +344,19 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples(
if err != nil {
return nil, err
}
return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples)
return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples)
}
func (v StateResolution) loadStateAfterEventsForNumericTuples(
ctx context.Context,
ctx context.Context, roomNID types.RoomNID,
prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
}
if len(prevStates) == 1 {
// Fast path for a single event.
prevState := prevStates[0]
@ -408,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples(
// TODO: Add metrics for this as it could take a long time for big rooms
// with large conflicts.
fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, prevStates)
fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates)
if err != nil {
return nil, err
}
@ -617,9 +608,13 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents(
prevStates []types.StateAtEvent,
metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) {
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return metrics.stop(0, err)
}
state, algorithm, conflictLength, err :=
v.calculateStateAfterManyEvents(ctx, prevStates)
v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates)
metrics.algorithm = algorithm
if err != nil {
return metrics.stop(0, err)
@ -633,7 +628,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents(
}
func (v StateResolution) calculateStateAfterManyEvents(
ctx context.Context, prevStates []types.StateAtEvent,
ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
var combined []types.StateEntry
// Conflict resolution.
@ -668,7 +664,7 @@ func (v StateResolution) calculateStateAfterManyEvents(
}
var resolved []types.StateEntry
resolved, err = v.resolveConflicts(ctx, notConflicted, conflicts)
resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts)
if err != nil {
algorithm = "_resolve_conflicts"
return
@ -684,10 +680,10 @@ func (v StateResolution) calculateStateAfterManyEvents(
}
func (v StateResolution) resolveConflicts(
ctx context.Context,
ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
stateResAlgo, err := v.version.StateResAlgorithm()
stateResAlgo, err := version.StateResAlgorithm()
if err != nil {
return nil, err
}

View file

@ -46,5 +46,4 @@ type Database interface {
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
}