Implement getEvents differently

This commit is contained in:
Neil Alexander 2020-10-13 11:29:22 +01:00
parent 1b14f872df
commit 8f8cc66c03
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -112,8 +112,6 @@ type txnReq struct {
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
// new events which the roomserver does not know about // new events which the roomserver does not know about
newEvents map[string]bool newEvents map[string]bool
// servers which we should fetch missing events from
servers []gomatrixserverlib.ServerName
} }
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
@ -336,6 +334,19 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
} }
} }
func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
servers := []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID)
}
return servers
}
func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error { func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
@ -365,22 +376,8 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er
return roomNotFoundError{e.RoomID()} return roomNotFoundError{e.RoomID()}
} }
// We will need to know this when fetching missing auth or prev events.
getServers := func() {
t.servers = []gomatrixserverlib.ServerName{}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
t.servers = append(t.servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(t.servers))
}
}
if len(stateResp.MissingAuthEventIDs) > 0 { if len(stateResp.MissingAuthEventIDs) > 0 {
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
getServers()
if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil { if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err) return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
} }
@ -388,7 +385,6 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er
if len(stateResp.MissingPrevEventIDs) > 0 { if len(stateResp.MissingPrevEventIDs) > 0 {
logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs))
getServers()
return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion) return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion)
} }
@ -416,14 +412,14 @@ func (t *txnReq) retrieveMissingAuthEvents(
missingAuthEvents[missingAuthEventID] = struct{}{} missingAuthEvents[missingAuthEventID] = struct{}{}
} }
numServers := len(t.servers) servers := t.getServers(ctx, e.RoomID())
if numServers > 5 { if len(servers) > 5 {
numServers = 5 servers = servers[:5]
} }
withNextEvent: withNextEvent:
for missingAuthEventID := range missingAuthEvents { for missingAuthEventID := range missingAuthEvents {
withNextServer: withNextServer:
for _, server := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers[:numServers]...) { for _, server := range servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil { if err != nil {
@ -435,14 +431,17 @@ withNextEvent:
logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID) logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue withNextServer continue withNextServer
} }
if err = api.SendEvents( if err = api.SendInputRoomEvents(
context.Background(), context.Background(),
t.rsAPI, t.rsAPI,
[]gomatrixserverlib.HeaderedEvent{ []api.InputRoomEvent{
ev.Headered(stateResp.RoomVersion), {
Kind: api.KindOutlier,
Event: ev.Headered(stateResp.RoomVersion),
AuthEventIDs: ev.AuthEventIDs(),
SendAsServer: api.DoNotSendToOtherServers,
},
}, },
api.DoNotSendToOtherServers,
nil,
); err != nil { ); err != nil {
return fmt.Errorf("api.SendEvents: %w", err) return fmt.Errorf("api.SendEvents: %w", err)
} }
@ -586,8 +585,13 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
} }
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// fetch the event we're missing and add it to the pile // fetch the event we're missing and add it to the pile
h, err := t.lookupEvent(ctx, roomVersion, eventID, false) h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return respState, nil return respState, nil
@ -695,7 +699,11 @@ retryAllowedState:
if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true) servers := t.getServers(ctx, backwardsExtremity.RoomID())
if len(servers) > 5 {
servers = servers[:5]
}
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true, servers)
switch err2.(type) { switch err2.(type) {
case verifySigError: case verifySigError:
return &gomatrixserverlib.RespState{ return &gomatrixserverlib.RespState{
@ -900,6 +908,12 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
"concurrent_requests": concurrentRequests, "concurrent_requests": concurrentRequests,
}).Info("Fetching missing state at event") }).Info("Fetching missing state at event")
// Get a list of servers to fetch from.
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// Create a queue containing all of the missing event IDs that we want // Create a queue containing all of the missing event IDs that we want
// to retrieve. // to retrieve.
pending := make(chan string, missingCount) pending := make(chan string, missingCount)
@ -925,7 +939,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return return
@ -989,7 +1003,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
return &respState, nil return &respState, nil
} }
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool, servers []gomatrixserverlib.ServerName) (*gomatrixserverlib.HeaderedEvent, error) {
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
queryReq := api.QueryEventsByIDRequest{ queryReq := api.QueryEventsByIDRequest{
@ -1004,11 +1018,7 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
} }
var event gomatrixserverlib.Event var event gomatrixserverlib.Event
found := false found := false
numServers := len(t.servers) for _, serverName := range servers {
if numServers > 5 {
numServers = 5
}
for _, serverName := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers[:numServers]...) {
txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")
@ -1023,8 +1033,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
break break
} }
if !found { if !found {
util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", numServers) util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", numServers) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers))
} }
if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())