Recursively fetch auth events if needed

This commit is contained in:
Neil Alexander 2020-10-12 11:20:11 +01:00
parent 0804594a61
commit 868f2d5a80
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -112,6 +112,8 @@ type txnReq struct {
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
// new events which the roomserver does not know about
newEvents map[string]bool
// servers which we should fetch missing events from
servers []gomatrixserverlib.ServerName
}
// A subset of FederationClient functionality that txn requires. Useful for testing.
@ -350,7 +352,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
}
var stateResp api.QueryMissingAuthPrevEventsResponse
if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil {
return err
return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err)
}
if !stateResp.RoomExists {
@ -366,45 +368,20 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
if len(stateResp.MissingAuthEventIDs) > 0 {
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
servers := []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(servers))
if len(t.servers) == 0 {
t.servers = []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
t.servers = append(t.servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(t.servers))
}
}
getAuthEvent:
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
for _, server := range servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
continue // try the next server
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID)
continue // try the next server
}
if err = api.SendInputRoomEvents(
context.Background(),
t.rsAPI,
[]api.InputRoomEvent{
{
Kind: api.KindOutlier,
Event: ev.Headered(stateResp.RoomVersion),
AuthEventIDs: ev.AuthEventIDs(),
SendAsServer: api.DoNotSendToOtherServers,
},
},
); err != nil {
logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID)
continue getAuthEvent // move onto the next event
}
}
if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
}
}
@ -427,6 +404,45 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
)
}
func (t *txnReq) retrieveMissingAuthEvents(
ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse,
) error {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
missingAuthEvents := make(map[string]struct{})
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
missingAuthEvents[missingAuthEventID] = struct{}{}
}
withNextEvent:
for missingAuthEventID := range missingAuthEvents {
withNextServer:
for _, server := range t.servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID)
continue withNextServer
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue withNextServer
}
if err = t.processEvent(ctx, ev, false); err != nil {
return fmt.Errorf("recursive t.processEvent: %w", err)
}
delete(missingAuthEvents, missingAuthEventID)
continue withNextEvent
}
}
if missing := len(missingAuthEvents); missing > 0 {
return fmt.Errorf("Event refers to %d auth_events which we failed to fetch", missing)
}
return nil
}
func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents {