Make get_missing_events test pass

This commit is contained in:
Till Faelligen 2022-10-11 13:35:10 +02:00
parent aee546f2b9
commit 79a9eaa20c
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
3 changed files with 16 additions and 10 deletions

View file

@ -324,7 +324,7 @@ func slowGetHistoryVisibilityState(
func ScanEventTree( func ScanEventTree(
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) { ) ([]types.EventNID, map[string]struct{}, error) {
var resultNIDs []types.EventNID var resultNIDs []types.EventNID
var err error var err error
var allowed bool var allowed bool
@ -345,6 +345,7 @@ func ScanEventTree(
var checkedServerInRoom bool var checkedServerInRoom bool
var isServerInRoom bool var isServerInRoom bool
redactEventIDs := make(map[string]struct{})
// Loop through the event IDs to retrieve the requested events and go // Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events' // through the whole tree (up to the provided limit) using the events'
@ -358,7 +359,7 @@ BFSLoop:
// Retrieve the events to process from the database. // Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, front) events, err = db.EventsFromIDs(ctx, front)
if err != nil { if err != nil {
return resultNIDs, err return resultNIDs, redactEventIDs, err
} }
if !checkedServerInRoom && len(events) > 0 { if !checkedServerInRoom && len(events) > 0 {
@ -395,16 +396,16 @@ BFSLoop:
) )
// drop the error, as we will often error at the DB level if we don't have the prev_event itself. Let's // drop the error, as we will often error at the DB level if we don't have the prev_event itself. Let's
// just return what we have. // just return what we have.
return resultNIDs, nil return resultNIDs, redactEventIDs, nil
} }
// If the event hasn't been seen before and the HS // If the event hasn't been seen before and the HS
// requesting to retrieve it is allowed to do so, add it to // requesting to retrieve it is allowed to do so, add it to
// the list of events to retrieve. // the list of events to retrieve.
if allowed { next = append(next, pre)
next = append(next, pre) if !allowed {
} else {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
redactEventIDs[pre] = struct{}{}
} }
} }
} }
@ -413,7 +414,7 @@ BFSLoop:
front = next front = next
} }
return resultNIDs, err return resultNIDs, redactEventIDs, err
} }
func QueryLatestEventsAndState( func QueryLatestEventsAndState(

View file

@ -78,7 +78,7 @@ func (r *Backfiller) PerformBackfill(
} }
// Scan the event tree for events to send back. // Scan the event tree for events to send back.
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -95,6 +95,9 @@ func (r *Backfiller) PerformBackfill(
} }
for _, event := range loadedEvents { for _, event := range loadedEvents {
if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact()
}
response.Events = append(response.Events, event.Headered(info.RoomVersion)) response.Events = append(response.Events, event.Headered(info.RoomVersion))
} }

View file

@ -453,7 +453,7 @@ func (r *Queryer) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
} }
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) resultNIDs, redact, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -470,7 +470,9 @@ func (r *Queryer) QueryMissingEvents(
if verr != nil { if verr != nil {
return verr return verr
} }
if _, ok := redact[event.EventID()]; ok {
event.Redact()
}
response.Events = append(response.Events, event.Headered(roomVersion)) response.Events = append(response.Events, event.Headered(roomVersion))
} }
} }