Implement getEvents differently

This commit is contained in:
Neil Alexander 2020-10-13 11:29:22 +01:00
parent 1b14f872df
commit 8f8cc66c03
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -112,8 +112,6 @@ type txnReq struct {
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
// new events which the roomserver does not know about
newEvents map[string]bool
// servers which we should fetch missing events from
servers []gomatrixserverlib.ServerName
}
// A subset of FederationClient functionality that txn requires. Useful for testing.
@ -336,6 +334,19 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
}
}
func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
servers := []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID)
}
return servers
}
func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
@ -365,22 +376,8 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er
return roomNotFoundError{e.RoomID()}
}
// We will need to know this when fetching missing auth or prev events.
getServers := func() {
t.servers = []gomatrixserverlib.ServerName{}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
t.servers = append(t.servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(t.servers))
}
}
if len(stateResp.MissingAuthEventIDs) > 0 {
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
getServers()
if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
}
@ -388,7 +385,6 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er
if len(stateResp.MissingPrevEventIDs) > 0 {
logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs))
getServers()
return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion)
}
@ -416,14 +412,14 @@ func (t *txnReq) retrieveMissingAuthEvents(
missingAuthEvents[missingAuthEventID] = struct{}{}
}
numServers := len(t.servers)
if numServers > 5 {
numServers = 5
servers := t.getServers(ctx, e.RoomID())
if len(servers) > 5 {
servers = servers[:5]
}
withNextEvent:
for missingAuthEventID := range missingAuthEvents {
withNextServer:
for _, server := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers[:numServers]...) {
for _, server := range servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
@ -435,14 +431,17 @@ withNextEvent:
logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue withNextServer
}
if err = api.SendEvents(
if err = api.SendInputRoomEvents(
context.Background(),
t.rsAPI,
[]gomatrixserverlib.HeaderedEvent{
ev.Headered(stateResp.RoomVersion),
[]api.InputRoomEvent{
{
Kind: api.KindOutlier,
Event: ev.Headered(stateResp.RoomVersion),
AuthEventIDs: ev.AuthEventIDs(),
SendAsServer: api.DoNotSendToOtherServers,
},
},
api.DoNotSendToOtherServers,
nil,
); err != nil {
return fmt.Errorf("api.SendEvents: %w", err)
}
@ -586,8 +585,13 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
}
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// fetch the event we're missing and add it to the pile
h, err := t.lookupEvent(ctx, roomVersion, eventID, false)
h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers)
switch err.(type) {
case verifySigError:
return respState, nil
@ -695,7 +699,11 @@ retryAllowedState:
if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true)
servers := t.getServers(ctx, backwardsExtremity.RoomID())
if len(servers) > 5 {
servers = servers[:5]
}
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true, servers)
switch err2.(type) {
case verifySigError:
return &gomatrixserverlib.RespState{
@ -900,6 +908,12 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
"concurrent_requests": concurrentRequests,
}).Info("Fetching missing state at event")
// Get a list of servers to fetch from.
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// Create a queue containing all of the missing event IDs that we want
// to retrieve.
pending := make(chan string, missingCount)
@ -925,7 +939,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
// Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false)
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers)
switch err.(type) {
case verifySigError:
return
@ -989,7 +1003,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
return &respState, nil
}
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool, servers []gomatrixserverlib.ServerName) (*gomatrixserverlib.HeaderedEvent, error) {
if localFirst {
// fetch from the roomserver
queryReq := api.QueryEventsByIDRequest{
@ -1004,11 +1018,7 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
}
var event gomatrixserverlib.Event
found := false
numServers := len(t.servers)
if numServers > 5 {
numServers = 5
}
for _, serverName := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers[:numServers]...) {
for _, serverName := range servers {
txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 {
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")
@ -1023,8 +1033,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
break
}
if !found {
util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", numServers)
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", numServers)
util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers))
}
if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())