mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Use background context when processing event with missing state (#1403)
* Use background context when processing event with missing state * Five minute timeout * Remove context from txnreq, thread through instead * Fix unit tests
This commit is contained in:
parent
b9caccbce8
commit
895ead8048
|
@ -19,6 +19,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
|
@ -43,7 +44,6 @@ func Send(
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
t := txnReq{
|
t := txnReq{
|
||||||
context: httpReq.Context(),
|
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
eduAPI: eduAPI,
|
eduAPI: eduAPI,
|
||||||
keys: keys,
|
keys: keys,
|
||||||
|
@ -82,7 +82,7 @@ func Send(
|
||||||
|
|
||||||
util.GetLogger(httpReq.Context()).Infof("Received transaction %q containing %d PDUs, %d EDUs", txnID, len(t.PDUs), len(t.EDUs))
|
util.GetLogger(httpReq.Context()).Infof("Received transaction %q containing %d PDUs, %d EDUs", txnID, len(t.PDUs), len(t.EDUs))
|
||||||
|
|
||||||
resp, jsonErr := t.processTransaction()
|
resp, jsonErr := t.processTransaction(httpReq.Context())
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed")
|
util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed")
|
||||||
return *jsonErr
|
return *jsonErr
|
||||||
|
@ -100,7 +100,6 @@ func Send(
|
||||||
|
|
||||||
type txnReq struct {
|
type txnReq struct {
|
||||||
gomatrixserverlib.Transaction
|
gomatrixserverlib.Transaction
|
||||||
context context.Context
|
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.RoomserverInternalAPI
|
||||||
eduAPI eduserverAPI.EDUServerInputAPI
|
eduAPI eduserverAPI.EDUServerInputAPI
|
||||||
keyAPI keyapi.KeyInternalAPI
|
keyAPI keyapi.KeyInternalAPI
|
||||||
|
@ -124,7 +123,7 @@ type txnFederationClient interface {
|
||||||
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONResponse) {
|
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
|
||||||
results := make(map[string]gomatrixserverlib.PDUResult)
|
results := make(map[string]gomatrixserverlib.PDUResult)
|
||||||
|
|
||||||
pdus := []gomatrixserverlib.HeaderedEvent{}
|
pdus := []gomatrixserverlib.HeaderedEvent{}
|
||||||
|
@ -133,15 +132,15 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
RoomID string `json:"room_id"`
|
RoomID string `json:"room_id"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(pdu, &header); err != nil {
|
if err := json.Unmarshal(pdu, &header); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to extract room ID from event")
|
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to extract room ID from event")
|
||||||
// We don't know the event ID at this point so we can't return the
|
// We don't know the event ID at this point so we can't return the
|
||||||
// failure in the PDU results
|
// failure in the PDU results
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID}
|
verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID}
|
||||||
verRes := api.QueryRoomVersionForRoomResponse{}
|
verRes := api.QueryRoomVersionForRoomResponse{}
|
||||||
if err := t.rsAPI.QueryRoomVersionForRoom(t.context, &verReq, &verRes); err != nil {
|
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
|
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
|
||||||
// We don't know the event ID at this point so we can't return the
|
// We don't know the event ID at this point so we can't return the
|
||||||
// failure in the PDU results
|
// failure in the PDU results
|
||||||
continue
|
continue
|
||||||
|
@ -161,17 +160,17 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
JSON: jsonerror.BadJSON("PDU contains bad JSON"),
|
JSON: jsonerror.BadJSON("PDU contains bad JSON"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu))
|
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if api.IsServerBannedFromRoom(t.context, t.rsAPI, event.RoomID(), t.Origin) {
|
if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) {
|
||||||
results[event.EventID()] = gomatrixserverlib.PDUResult{
|
results[event.EventID()] = gomatrixserverlib.PDUResult{
|
||||||
Error: "Forbidden by server ACLs",
|
Error: "Forbidden by server ACLs",
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil {
|
if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
||||||
results[event.EventID()] = gomatrixserverlib.PDUResult{
|
results[event.EventID()] = gomatrixserverlib.PDUResult{
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
}
|
}
|
||||||
|
@ -182,7 +181,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
|
|
||||||
// Process the events.
|
// Process the events.
|
||||||
for _, e := range pdus {
|
for _, e := range pdus {
|
||||||
if err := t.processEvent(e.Unwrap(), true); err != nil {
|
if err := t.processEvent(ctx, e.Unwrap(), true); err != nil {
|
||||||
// If the error is due to the event itself being bad then we skip
|
// If the error is due to the event itself being bad then we skip
|
||||||
// it and move onto the next event. We report an error so that the
|
// it and move onto the next event. We report an error so that the
|
||||||
// sender knows that we have skipped processing it.
|
// sender knows that we have skipped processing it.
|
||||||
|
@ -201,7 +200,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
if isProcessingErrorFatal(err) {
|
if isProcessingErrorFatal(err) {
|
||||||
// Any other error should be the result of a temporary error in
|
// Any other error should be the result of a temporary error in
|
||||||
// our server so we should bail processing the transaction entirely.
|
// our server so we should bail processing the transaction entirely.
|
||||||
util.GetLogger(t.context).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
|
util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
|
||||||
jsonErr := util.ErrorResponse(err)
|
jsonErr := util.ErrorResponse(err)
|
||||||
return nil, &jsonErr
|
return nil, &jsonErr
|
||||||
} else {
|
} else {
|
||||||
|
@ -211,7 +210,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
if rejected {
|
if rejected {
|
||||||
errMsg = ""
|
errMsg = ""
|
||||||
}
|
}
|
||||||
util.GetLogger(t.context).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
|
util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
|
||||||
"Failed to process incoming federation event, skipping",
|
"Failed to process incoming federation event, skipping",
|
||||||
)
|
)
|
||||||
results[e.EventID()] = gomatrixserverlib.PDUResult{
|
results[e.EventID()] = gomatrixserverlib.PDUResult{
|
||||||
|
@ -223,9 +222,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.processEDUs(t.EDUs)
|
t.processEDUs(ctx)
|
||||||
if c := len(results); c > 0 {
|
if c := len(results); c > 0 {
|
||||||
util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
|
util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
|
||||||
}
|
}
|
||||||
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
||||||
}
|
}
|
||||||
|
@ -284,8 +283,9 @@ func (t *txnReq) haveEventIDs() map[string]bool {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
// nolint:gocyclo
|
||||||
for _, e := range edus {
|
func (t *txnReq) processEDUs(ctx context.Context) {
|
||||||
|
for _, e := range t.EDUs {
|
||||||
switch e.Type {
|
switch e.Type {
|
||||||
case gomatrixserverlib.MTyping:
|
case gomatrixserverlib.MTyping:
|
||||||
// https://matrix.org/docs/spec/server_server/latest#typing-notifications
|
// https://matrix.org/docs/spec/server_server/latest#typing-notifications
|
||||||
|
@ -295,24 +295,24 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
||||||
Typing bool `json:"typing"`
|
Typing bool `json:"typing"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(e.Content, &typingPayload); err != nil {
|
if err := json.Unmarshal(e.Content, &typingPayload); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal typing event")
|
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := eduserverAPI.SendTyping(t.context, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
|
if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server")
|
util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to edu server")
|
||||||
}
|
}
|
||||||
case gomatrixserverlib.MDirectToDevice:
|
case gomatrixserverlib.MDirectToDevice:
|
||||||
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
|
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
|
||||||
var directPayload gomatrixserverlib.ToDeviceMessage
|
var directPayload gomatrixserverlib.ToDeviceMessage
|
||||||
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
|
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events")
|
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal send-to-device events")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for userID, byUser := range directPayload.Messages {
|
for userID, byUser := range directPayload.Messages {
|
||||||
for deviceID, message := range byUser {
|
for deviceID, message := range byUser {
|
||||||
// TODO: check that the user and the device actually exist here
|
// TODO: check that the user and the device actually exist here
|
||||||
if err := eduserverAPI.SendToDevice(t.context, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
|
if err := eduserverAPI.SendToDevice(ctx, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
|
||||||
"sender": directPayload.Sender,
|
"sender": directPayload.Sender,
|
||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
"device_id": deviceID,
|
"device_id": deviceID,
|
||||||
|
@ -321,17 +321,17 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case gomatrixserverlib.MDeviceListUpdate:
|
case gomatrixserverlib.MDeviceListUpdate:
|
||||||
t.processDeviceListUpdate(e)
|
t.processDeviceListUpdate(ctx, e)
|
||||||
default:
|
default:
|
||||||
util.GetLogger(t.context).WithField("type", e.Type).Debug("Unhandled EDU")
|
util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) {
|
func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) {
|
||||||
var payload gomatrixserverlib.DeviceListUpdateEvent
|
var payload gomatrixserverlib.DeviceListUpdateEvent
|
||||||
if err := json.Unmarshal(e.Content, &payload); err != nil {
|
if err := json.Unmarshal(e.Content, &payload); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event")
|
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var inputRes keyapi.InputDeviceListUpdateResponse
|
var inputRes keyapi.InputDeviceListUpdateResponse
|
||||||
|
@ -339,11 +339,11 @@ func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) {
|
||||||
Event: payload,
|
Event: payload,
|
||||||
}, &inputRes)
|
}, &inputRes)
|
||||||
if inputRes.Error != nil {
|
if inputRes.Error != nil {
|
||||||
util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
|
util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error {
|
func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error {
|
||||||
prevEventIDs := e.PrevEventIDs()
|
prevEventIDs := e.PrevEventIDs()
|
||||||
|
|
||||||
// Fetch the state needed to authenticate the event.
|
// Fetch the state needed to authenticate the event.
|
||||||
|
@ -354,7 +354,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro
|
||||||
StateToFetch: needed.Tuples(),
|
StateToFetch: needed.Tuples(),
|
||||||
}
|
}
|
||||||
var stateResp api.QueryStateAfterEventsResponse
|
var stateResp api.QueryStateAfterEventsResponse
|
||||||
if err := t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil {
|
if err := t.rsAPI.QueryStateAfterEvents(ctx, &stateReq, &stateResp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -369,7 +369,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stateResp.PrevEventsExist {
|
if !stateResp.PrevEventsExist {
|
||||||
return t.processEventWithMissingState(e, stateResp.RoomVersion, isInboundTxn)
|
return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the event is allowed by the state at the event.
|
// Check that the event is allowed by the state at the event.
|
||||||
|
@ -379,7 +379,8 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro
|
||||||
|
|
||||||
// pass the event to the roomserver
|
// pass the event to the roomserver
|
||||||
return api.SendEvents(
|
return api.SendEvents(
|
||||||
t.context, t.rsAPI,
|
context.Background(),
|
||||||
|
t.rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
e.Headered(stateResp.RoomVersion),
|
e.Headered(stateResp.RoomVersion),
|
||||||
},
|
},
|
||||||
|
@ -399,7 +400,12 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver
|
||||||
return gomatrixserverlib.Allowed(e, &authUsingState)
|
return gomatrixserverlib.Allowed(e, &authUsingState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error {
|
func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error {
|
||||||
|
// Do this with a fresh context, so that we keep working even if the
|
||||||
|
// original request times out. With any luck, by the time the remote
|
||||||
|
// side retries, we'll have fetched the missing state.
|
||||||
|
gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
|
||||||
|
defer cancel()
|
||||||
// We are missing the previous events for this events.
|
// We are missing the previous events for this events.
|
||||||
// This means that there is a gap in our view of the history of the
|
// This means that there is a gap in our view of the history of the
|
||||||
// room. There two ways that we can handle such a gap:
|
// room. There two ways that we can handle such a gap:
|
||||||
|
@ -420,7 +426,7 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer
|
||||||
// - fill in the gap completely then process event `e` returning no backwards extremity
|
// - fill in the gap completely then process event `e` returning no backwards extremity
|
||||||
// - fail to fill in the gap and tell us to terminate the transaction err=not nil
|
// - fail to fill in the gap and tell us to terminate the transaction err=not nil
|
||||||
// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
|
// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
|
||||||
backwardsExtremity, err := t.getMissingEvents(e, roomVersion, isInboundTxn)
|
backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -437,16 +443,16 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer
|
||||||
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
|
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
|
||||||
for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
|
for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
|
||||||
var prevState *gomatrixserverlib.RespState
|
var prevState *gomatrixserverlib.RespState
|
||||||
prevState, err = t.lookupStateAfterEvent(roomVersion, backwardsExtremity.RoomID(), prevEventID, needed)
|
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
|
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
states = append(states, prevState)
|
states = append(states, prevState)
|
||||||
}
|
}
|
||||||
resolvedState, err := t.resolveStatesAndCheck(roomVersion, states, backwardsExtremity)
|
resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
|
util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -457,20 +463,20 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer
|
||||||
|
|
||||||
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
|
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
|
||||||
// added into the mix.
|
// added into the mix.
|
||||||
func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) {
|
func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) {
|
||||||
// try doing all this locally before we resort to querying federation
|
// try doing all this locally before we resort to querying federation
|
||||||
respState := t.lookupStateAfterEventLocally(roomID, eventID, needed)
|
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed)
|
||||||
if respState != nil {
|
if respState != nil {
|
||||||
return respState, nil
|
return respState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
respState, err := t.lookupStateBeforeEvent(roomVersion, roomID, eventID)
|
respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetch the event we're missing and add it to the pile
|
// fetch the event we're missing and add it to the pile
|
||||||
h, err := t.lookupEvent(roomVersion, eventID, false)
|
h, err := t.lookupEvent(ctx, roomVersion, eventID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -493,15 +499,15 @@ func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion
|
||||||
return respState, nil
|
return respState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState {
|
func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState {
|
||||||
var res api.QueryStateAfterEventsResponse
|
var res api.QueryStateAfterEventsResponse
|
||||||
err := t.rsAPI.QueryStateAfterEvents(t.context, &api.QueryStateAfterEventsRequest{
|
err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
PrevEventIDs: []string{eventID},
|
PrevEventIDs: []string{eventID},
|
||||||
StateToFetch: needed,
|
StateToFetch: needed,
|
||||||
}, &res)
|
}, &res)
|
||||||
if err != nil || !res.PrevEventsExist {
|
if err != nil || !res.PrevEventsExist {
|
||||||
util.GetLogger(t.context).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i, ev := range res.StateEvents {
|
for i, ev := range res.StateEvents {
|
||||||
|
@ -528,9 +534,9 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
queryReq := api.QueryEventsByIDRequest{
|
||||||
EventIDs: missingEventList,
|
EventIDs: missingEventList,
|
||||||
}
|
}
|
||||||
util.GetLogger(t.context).Infof("Fetching missing auth events: %v", missingEventList)
|
util.GetLogger(ctx).Infof("Fetching missing auth events: %v", missingEventList)
|
||||||
var queryRes api.QueryEventsByIDResponse
|
var queryRes api.QueryEventsByIDResponse
|
||||||
if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil {
|
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i := range queryRes.Events {
|
for i := range queryRes.Events {
|
||||||
|
@ -548,22 +554,22 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g
|
||||||
|
|
||||||
// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what
|
// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what
|
||||||
// the server supports.
|
// the server supports.
|
||||||
func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (
|
func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (
|
||||||
respState *gomatrixserverlib.RespState, err error) {
|
respState *gomatrixserverlib.RespState, err error) {
|
||||||
|
|
||||||
util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s", eventID)
|
util.GetLogger(ctx).Infof("lookupStateBeforeEvent %s", eventID)
|
||||||
|
|
||||||
// Attempt to fetch the missing state using /state_ids and /events
|
// Attempt to fetch the missing state using /state_ids and /events
|
||||||
respState, err = t.lookupMissingStateViaStateIDs(roomID, eventID, roomVersion)
|
respState, err = t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Fallback to /state
|
// Fallback to /state
|
||||||
util.GetLogger(t.context).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state")
|
util.GetLogger(ctx).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state")
|
||||||
respState, err = t.lookupMissingStateViaState(roomID, eventID, roomVersion)
|
respState, err = t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) resolveStatesAndCheck(roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) {
|
func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) {
|
||||||
var authEventList []gomatrixserverlib.Event
|
var authEventList []gomatrixserverlib.Event
|
||||||
var stateEventList []gomatrixserverlib.Event
|
var stateEventList []gomatrixserverlib.Event
|
||||||
for _, state := range states {
|
for _, state := range states {
|
||||||
|
@ -579,11 +585,11 @@ retryAllowedState:
|
||||||
if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
|
if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil {
|
||||||
switch missing := err.(type) {
|
switch missing := err.(type) {
|
||||||
case gomatrixserverlib.MissingAuthEventError:
|
case gomatrixserverlib.MissingAuthEventError:
|
||||||
h, err2 := t.lookupEvent(roomVersion, missing.AuthEventID, true)
|
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2)
|
return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2)
|
||||||
}
|
}
|
||||||
util.GetLogger(t.context).Infof("fetched event %s", missing.AuthEventID)
|
util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID)
|
||||||
resolvedStateEvents = append(resolvedStateEvents, h.Unwrap())
|
resolvedStateEvents = append(resolvedStateEvents, h.Unwrap())
|
||||||
goto retryAllowedState
|
goto retryAllowedState
|
||||||
default:
|
default:
|
||||||
|
@ -600,12 +606,12 @@ retryAllowedState:
|
||||||
// begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events
|
// begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events
|
||||||
// This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns.
|
// This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns.
|
||||||
// This means that we may recursively call this function, as we spider back up prev_events to the min depth.
|
// This means that we may recursively call this function, as we spider back up prev_events to the min depth.
|
||||||
func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) {
|
func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) {
|
||||||
if !isInboundTxn {
|
if !isInboundTxn {
|
||||||
// we've recursed here, so just take a state snapshot please!
|
// we've recursed here, so just take a state snapshot please!
|
||||||
return &e, nil
|
return &e, nil
|
||||||
}
|
}
|
||||||
logger := util.GetLogger(t.context).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
|
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
|
||||||
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e})
|
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e})
|
||||||
// query latest events (our trusted forward extremities)
|
// query latest events (our trusted forward extremities)
|
||||||
req := api.QueryLatestEventsAndStateRequest{
|
req := api.QueryLatestEventsAndStateRequest{
|
||||||
|
@ -613,7 +619,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri
|
||||||
StateToFetch: needed.Tuples(),
|
StateToFetch: needed.Tuples(),
|
||||||
}
|
}
|
||||||
var res api.QueryLatestEventsAndStateResponse
|
var res api.QueryLatestEventsAndStateResponse
|
||||||
if err = t.rsAPI.QueryLatestEventsAndState(t.context, &req, &res); err != nil {
|
if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil {
|
||||||
logger.WithError(err).Warn("Failed to query latest events")
|
logger.WithError(err).Warn("Failed to query latest events")
|
||||||
return &e, nil
|
return &e, nil
|
||||||
}
|
}
|
||||||
|
@ -626,7 +632,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri
|
||||||
if minDepth < 0 {
|
if minDepth < 0 {
|
||||||
minDepth = 0
|
minDepth = 0
|
||||||
}
|
}
|
||||||
missingResp, err := t.federation.LookupMissingEvents(t.context, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{
|
missingResp, err := t.federation.LookupMissingEvents(ctx, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{
|
||||||
Limit: 20,
|
Limit: 20,
|
||||||
// synapse uses the min depth they've ever seen in that room
|
// synapse uses the min depth they've ever seen in that room
|
||||||
MinDepth: minDepth,
|
MinDepth: minDepth,
|
||||||
|
@ -685,7 +691,7 @@ Event:
|
||||||
}
|
}
|
||||||
// process the missing events then the event which started this whole thing
|
// process the missing events then the event which started this whole thing
|
||||||
for _, ev := range append(newEvents, e) {
|
for _, ev := range append(newEvents, e) {
|
||||||
err := t.processEvent(ev, false)
|
err := t.processEvent(ctx, ev, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -695,24 +701,24 @@ Event:
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) lookupMissingStateViaState(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
||||||
respState *gomatrixserverlib.RespState, err error) {
|
respState *gomatrixserverlib.RespState, err error) {
|
||||||
state, err := t.federation.LookupState(t.context, t.Origin, roomID, eventID, roomVersion)
|
state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Check that the returned state is valid.
|
// Check that the returned state is valid.
|
||||||
if err := state.Check(t.context, t.keys, nil); err != nil {
|
if err := state.Check(ctx, t.keys, nil); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &state, nil
|
return &state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
|
||||||
*gomatrixserverlib.RespState, error) {
|
*gomatrixserverlib.RespState, error) {
|
||||||
util.GetLogger(t.context).Infof("lookupMissingStateViaStateIDs %s", eventID)
|
util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID)
|
||||||
// fetch the state event IDs at the time of the event
|
// fetch the state event IDs at the time of the event
|
||||||
stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, roomID, eventID)
|
stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -734,7 +740,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi
|
||||||
EventIDs: missingEventList,
|
EventIDs: missingEventList,
|
||||||
}
|
}
|
||||||
var queryRes api.QueryEventsByIDResponse
|
var queryRes api.QueryEventsByIDResponse
|
||||||
if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil {
|
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for i := range queryRes.Events {
|
for i := range queryRes.Events {
|
||||||
|
@ -745,7 +751,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
util.GetLogger(t.context).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"missing": len(missing),
|
"missing": len(missing),
|
||||||
"event_id": eventID,
|
"event_id": eventID,
|
||||||
"room_id": roomID,
|
"room_id": roomID,
|
||||||
|
@ -755,7 +761,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi
|
||||||
|
|
||||||
for missingEventID := range missing {
|
for missingEventID := range missing {
|
||||||
var h *gomatrixserverlib.HeaderedEvent
|
var h *gomatrixserverlib.HeaderedEvent
|
||||||
h, err = t.lookupEvent(roomVersion, missingEventID, false)
|
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -793,33 +799,33 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
|
||||||
return &respState, nil
|
return &respState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) lookupEvent(roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
|
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
if localFirst {
|
if localFirst {
|
||||||
// fetch from the roomserver
|
// fetch from the roomserver
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
queryReq := api.QueryEventsByIDRequest{
|
||||||
EventIDs: []string{missingEventID},
|
EventIDs: []string{missingEventID},
|
||||||
}
|
}
|
||||||
var queryRes api.QueryEventsByIDResponse
|
var queryRes api.QueryEventsByIDResponse
|
||||||
if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil {
|
if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
||||||
util.GetLogger(t.context).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
||||||
} else if len(queryRes.Events) == 1 {
|
} else if len(queryRes.Events) == 1 {
|
||||||
return &queryRes.Events[0], nil
|
return &queryRes.Events[0], nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID)
|
txn, err := t.federation.GetEvent(ctx, t.Origin, missingEventID)
|
||||||
if err != nil || len(txn.PDUs) == 0 {
|
if err != nil || len(txn.PDUs) == 0 {
|
||||||
util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID")
|
util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pdu := txn.PDUs[0]
|
pdu := txn.PDUs[0]
|
||||||
var event gomatrixserverlib.Event
|
var event gomatrixserverlib.Event
|
||||||
event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
|
event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID())
|
||||||
return nil, unmarshalError{err}
|
return nil, unmarshalError{err}
|
||||||
}
|
}
|
||||||
if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil {
|
if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil {
|
||||||
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
}
|
}
|
||||||
h := event.Headered(roomVersion)
|
h := event.Headered(roomVersion)
|
||||||
|
|
|
@ -365,7 +365,6 @@ func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserver
|
||||||
|
|
||||||
func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq {
|
func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq {
|
||||||
t := &txnReq{
|
t := &txnReq{
|
||||||
context: context.Background(),
|
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
eduAPI: &testEDUProducer{},
|
eduAPI: &testEDUProducer{},
|
||||||
keys: &test.NopJSONVerifier{},
|
keys: &test.NopJSONVerifier{},
|
||||||
|
@ -381,7 +380,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) {
|
func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) {
|
||||||
res, err := txn.processTransaction()
|
res, err := txn.processTransaction(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("txn.processTransaction returned an error: %v", err)
|
t.Errorf("txn.processTransaction returned an error: %v", err)
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in a new issue