This commit is contained in:
Kegan Dougal 2020-04-27 19:18:26 +01:00
parent 7ae6884cc8
commit be7c636e70
2 changed files with 74 additions and 55 deletions

View file

@ -49,40 +49,12 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent
if !ok { if !ok {
goto FederationHit goto FederationHit
} }
// The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs)
newStateIDs := prevEventStateIDs[:] if newStateIDs != nil {
if prevEvent.StateKey() == nil {
// state is the same as the previous event
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs, nil return newStateIDs, nil
} }
// else we failed to calculate the new state, so fallthrough
missingState := false // true if we are missing the info for a state event ID
foundEvent := false // true if we found a (type, state_key) match
// find which state ID to replace, if any
for i, id := range newStateIDs {
ev, ok := b.eventIDMap[id]
if !ok {
missingState = true
continue
}
if ev.Type() == prevEvent.Type() && ev.StateKey() != nil && ev.StateKey() == prevEvent.StateKey() {
newStateIDs[i] = prevEvent.EventID()
foundEvent = true
break
}
}
if !foundEvent && !missingState {
// we can be certain that this is new state
newStateIDs = append(newStateIDs, prevEvent.EventID())
foundEvent = true
}
if foundEvent {
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs, nil
}
// else fallthrough because we don't know if one of the missing state IDs was the one we could replace.
} }
FederationHit: FederationHit:
@ -105,6 +77,43 @@ FederationHit:
return nil, lastErr return nil, lastErr
} }
func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string {
newStateIDs := prevEventStateIDs[:]
if prevEvent.StateKey() == nil {
// state is the same as the previous event
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs
}
missingState := false // true if we are missing the info for a state event ID
foundEvent := false // true if we found a (type, state_key) match
// find which state ID to replace, if any
for i, id := range newStateIDs {
ev, ok := b.eventIDMap[id]
if !ok {
missingState = true
continue
}
// The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself
if ev.Type() == prevEvent.Type() && ev.StateKey() != nil && ev.StateKey() == prevEvent.StateKey() {
newStateIDs[i] = prevEvent.EventID()
foundEvent = true
break
}
}
if !foundEvent && !missingState {
// we can be certain that this is new state
newStateIDs = append(newStateIDs, prevEvent.EventID())
foundEvent = true
}
if foundEvent {
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs
}
return nil
}
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
// try to fetch the events from the database first // try to fetch the events from the database first
events, err := b.ProvideEvents(roomVer, eventIDs) events, err := b.ProvideEvents(roomVer, eventIDs)

View file

@ -551,37 +551,18 @@ func (r *RoomserverQueryAPI) backfillViaFederation(ctx context.Context, req *api
} }
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
backfilledEventMap := make(map[string]types.Event)
var roomNID types.RoomNID
// persist these new events - auth checks have already been done // persist these new events - auth checks have already been done
for _, ev := range events { roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
nidMap, err := r.DB.EventNIDs(ctx, ev.AuthEventIDs())
if err != nil { // this shouldn't happen as RequestBackill already found them
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
continue
}
authNids := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
authNids[i] = nid
i++
}
var stateAtEvent types.StateAtEvent
roomNID, stateAtEvent, err = r.DB.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event") return err
continue
}
backfilledEventMap[ev.EventID()] = types.Event{
EventNID: stateAtEvent.StateEntry.EventNID,
Event: ev.Unwrap(),
}
} }
for _, ev := range backfilledEventMap { for _, ev := range backfilledEventMap {
// now add state for these events // now add state for these events
stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
if !ok { if !ok {
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
// which requires a list of state IDs.
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks")
continue continue
} }
@ -858,6 +839,35 @@ func getAuthChain(
return authEvents, nil return authEvents, nil
} }
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID
backfilledEventMap := make(map[string]types.Event)
for _, ev := range events {
nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
if err != nil { // this shouldn't happen as RequestBackfill already found them
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
continue
}
authNids := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
authNids[i] = nid
i++
}
var stateAtEvent types.StateAtEvent
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event")
continue
}
backfilledEventMap[ev.EventID()] = types.Event{
EventNID: stateAtEvent.StateEntry.EventNID,
Event: ev.Unwrap(),
}
}
return roomNID, backfilledEventMap
}
// QueryRoomVersionCapabilities implements api.RoomserverQueryAPI // QueryRoomVersionCapabilities implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryRoomVersionCapabilities( func (r *RoomserverQueryAPI) QueryRoomVersionCapabilities(
ctx context.Context, ctx context.Context,