Ensure the input API only uses a single transaction

This commit is contained in:
Neil Alexander 2022-02-11 13:04:29 +00:00
parent f800cae6d2
commit 421d819d83
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 100 additions and 149 deletions

View file

@ -83,13 +83,6 @@ type RoomserverInternalAPI interface {
response *QueryStateAfterEventsResponse,
) error
// Query whether the roomserver is missing any auth or prev events.
QueryMissingAuthPrevEvents(
ctx context.Context,
request *QueryMissingAuthPrevEventsRequest,
response *QueryMissingAuthPrevEventsResponse,
) error
// Query a list of events by event ID.
QueryEventsByID(
ctx context.Context,

View file

@ -129,16 +129,6 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents(
return err
}
func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents(
ctx context.Context,
req *QueryMissingAuthPrevEventsRequest,
res *QueryMissingAuthPrevEventsResponse,
) error {
err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) QueryEventsByID(
ctx context.Context,
req *QueryEventsByIDRequest,

View file

@ -128,20 +128,16 @@ func (r *Inputer) processRoomEvent(
}
}
missingRes := &api.QueryMissingAuthPrevEventsResponse{}
var missingAuth, missingPrev bool
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") {
missingReq := &api.QueryMissingAuthPrevEventsRequest{
RoomID: event.RoomID(),
AuthEventIDs: event.AuthEventIDs(),
PrevEventIDs: event.PrevEventIDs(),
}
if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event)
if err != nil {
return rollbackTransaction, fmt.Errorf("r.DB.MissingAuthPrevEvents: %w", err)
}
missingAuth = len(missingAuthIDs) > 0
missingPrev = !input.HasState && len(missingPrevIDs) > 0
}
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0
if missingAuth || missingPrev {
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
@ -246,14 +242,13 @@ func (r *Inputer) processRoomEvent(
missingState := missingStateReq{
origin: input.Origin,
inputer: r,
queryer: r.Queryer,
db: updater,
federation: r.FSAPI,
keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(),
servers: serverRes.ServerNames,
hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
haveEvents: map[string]*gomatrixserverlib.Event{},
}
if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
// Something went wrong with retrieving the missing state, so we can't

View file

@ -10,7 +10,7 @@ import (
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@ -27,14 +27,14 @@ type missingStateReq struct {
origin gomatrixserverlib.ServerName
db *shared.RoomUpdater
inputer *Inputer
queryer *query.Queryer
roomInfo *types.RoomInfo
keys gomatrixserverlib.JSONVerifier
federation fedapi.FederationInternalAPI
roomsMu *internal.MutexByRoom
servers []gomatrixserverlib.ServerName
hadEvents map[string]bool
hadEventsMutex sync.Mutex
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
haveEvents map[string]*gomatrixserverlib.Event
haveEventsMutex sync.Mutex
}
@ -326,20 +326,20 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
for i := range respState.StateEvents {
se := respState.StateEvents[i]
if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) {
respState.StateEvents[i] = h.Unwrap()
respState.StateEvents[i] = h
addedToState = true
break
}
}
if !addedToState {
respState.StateEvents = append(respState.StateEvents, h.Unwrap())
respState.StateEvents = append(respState.StateEvents, h)
}
}
return respState, false, nil
}
func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent {
func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event {
t.haveEventsMutex.Lock()
defer t.haveEventsMutex.Unlock()
if cached, exists := t.haveEvents[ev.EventID()]; exists {
@ -350,32 +350,51 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g
}
func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState {
var res api.QueryStateAfterEventsResponse
err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
RoomID: roomID,
PrevEventIDs: []string{eventID},
}, &res)
if err != nil || !res.PrevEventsExist {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist)
var res parsedRespState
roomInfo, err := t.db.RoomInfo(ctx, roomID)
if err != nil {
return nil
}
stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents))
for i, ev := range res.StateEvents {
roomState := state.NewStateResolution(t.db, roomInfo)
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
if err != nil {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID)
return nil
}
stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, stateAtEvents)
if err != nil {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load combined state after %s locally", eventID)
return nil
}
stateEventNIDs := make([]types.EventNID, 0, len(stateEntries))
for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEvents, err := t.db.Events(ctx, stateEventNIDs)
if err != nil {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load state events locally")
return nil
}
res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents))
for _, ev := range stateEvents {
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
// processEvent request, which is better for memory.
stateEvents[i] = t.cacheAndReturn(ev)
res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event))
t.hadEvent(ev.EventID())
}
// we should never access res.StateEvents again so we delete it here to make GC faster
res.StateEvents = nil
var authEvents []*gomatrixserverlib.Event
stateEvents = nil
stateEventNIDs = nil
stateEntries = nil
stateAtEvents = nil
missingAuthEvents := map[string]bool{}
res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3)
for _, ev := range stateEvents {
t.haveEventsMutex.Lock()
for _, ae := range ev.AuthEventIDs() {
if aev, ok := t.haveEvents[ae]; ok {
authEvents = append(authEvents, aev.Unwrap())
res.AuthEvents = append(res.AuthEvents, aev)
} else {
missingAuthEvents[ae] = true
}
@ -389,25 +408,18 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
for evID := range missingAuthEvents {
missingEventList = append(missingEventList, evID)
}
queryReq := api.QueryEventsByIDRequest{
EventIDs: missingEventList,
}
util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
var queryRes api.QueryEventsByIDResponse
if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
events, err := t.db.EventsFromIDs(ctx, missingEventList)
if err != nil {
return nil
}
for i, ev := range queryRes.Events {
authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
for i, ev := range events {
res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event))
t.hadEvent(ev.EventID())
}
queryRes.Events = nil
}
return &parsedRespState{
StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents),
AuthEvents: authEvents,
}
return &res
}
// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what
@ -448,7 +460,7 @@ retryAllowedState:
return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2)
}
util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID)
resolvedStateEvents = append(resolvedStateEvents, h.Unwrap())
resolvedStateEvents = append(resolvedStateEvents, h)
goto retryAllowedState
default:
}
@ -513,7 +525,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
logger.Debugf("get_missing_events returned %d events", len(missingResp.Events))
missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap())
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
}
// topologically sort and sanity check that we are making forward progress
@ -602,11 +614,11 @@ func (t *missingStateReq) lookupMissingStateViaState(
// We load these as trusted as we called state.Check before which loaded them as untrusted.
for i, evJSON := range state.AuthEvents {
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
parsedState.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
parsedState.AuthEvents[i] = t.cacheAndReturn(ev)
}
for i, evJSON := range state.StateEvents {
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
parsedState.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
parsedState.StateEvents[i] = t.cacheAndReturn(ev)
}
return parsedState, nil
}
@ -634,23 +646,20 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
}
t.haveEventsMutex.Unlock()
// fetch as many as we can from the roomserver
queryReq := api.QueryEventsByIDRequest{
EventIDs: missingEventList,
events, err := t.db.EventsFromIDs(ctx, missingEventList)
if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
}
var queryRes api.QueryEventsByIDResponse
if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
return nil, err
}
for i, ev := range queryRes.Events {
queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i])
for i, ev := range events {
events[i].Event = t.cacheAndReturn(events[i].Event)
t.hadEvent(ev.EventID())
evID := queryRes.Events[i].EventID()
evID := events[i].EventID()
if missing[evID] {
delete(missing, evID)
}
}
queryRes.Events = nil // allow it to be GCed
events = nil // allow GC
concurrentRequests := 8
missingCount := len(missing)
@ -704,7 +713,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
// Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent
var h *gomatrixserverlib.Event
h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
switch err.(type) {
case verifySigError:
@ -759,7 +768,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(
logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
continue
}
respState.StateEvents = append(respState.StateEvents, ev.Unwrap())
respState.StateEvents = append(respState.StateEvents, ev)
}
for i := range stateIDs.AuthEventIDs {
ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
@ -767,7 +776,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(
logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
continue
}
respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap())
respState.AuthEvents = append(respState.AuthEvents, ev)
}
// We purposefully do not do auth checks on the returned events, as they will still
// be processed in the exact same way, just as a 'rejected' event
@ -775,17 +784,14 @@ func (t *missingStateReq) createRespStateFromStateIDs(
return &respState, nil
}
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) {
if localFirst {
// fetch from the roomserver
queryReq := api.QueryEventsByIDRequest{
EventIDs: []string{missingEventID},
}
var queryRes api.QueryEventsByIDResponse
if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
events, err := t.db.EventsFromIDs(ctx, []string{missingEventID})
if err != nil {
util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(queryRes.Events) == 1 {
return queryRes.Events[0], nil
} else if len(events) == 1 {
return events[0].Event, nil
}
}
var event *gomatrixserverlib.Event
@ -822,7 +828,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err}
}
return t.cacheAndReturn(event.Headered(roomVersion)), nil
return t.cacheAndReturn(event), nil
}
func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error {

View file

@ -125,39 +125,6 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI
func (r *Queryer) QueryMissingAuthPrevEvents(
ctx context.Context,
request *api.QueryMissingAuthPrevEventsRequest,
response *api.QueryMissingAuthPrevEventsResponse,
) error {
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info == nil {
return errors.New("room doesn't exist")
}
response.RoomExists = !info.IsStub
response.RoomVersion = info.RoomVersion
for _, authEventID := range request.AuthEventIDs {
if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID)
}
}
for _, prevEventID := range request.PrevEventIDs {
state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID})
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID)
}
}
return nil
}
// QueryEventsByID implements api.RoomserverInternalAPI
func (r *Queryer) QueryEventsByID(
ctx context.Context,

View file

@ -40,7 +40,6 @@ const (
// Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents"
RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents"
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
@ -302,19 +301,6 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryStateAfterEvents implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents(
ctx context.Context,
request *api.QueryMissingAuthPrevEventsRequest,
response *api.QueryMissingAuthPrevEventsResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryEventsByID implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryEventsByID(
ctx context.Context,

View file

@ -149,20 +149,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
RoomserverQueryMissingAuthPrevEventsPath,
httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryMissingAuthPrevEventsRequest
var response api.QueryMissingAuthPrevEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
RoomserverQueryEventsByIDPath,
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {

View file

@ -97,6 +97,34 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
func (u *RoomUpdater) MissingAuthPrevEvents(
ctx context.Context, e *gomatrixserverlib.Event,
) (missingAuth, missingPrev []string, err error) {
var info *types.RoomInfo
info, err = u.RoomInfo(ctx, e.RoomID())
if err != nil {
return
}
if info == nil || !info.IsStub {
return
}
for _, authEventID := range e.AuthEventIDs() {
if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
missingAuth = append(missingAuth, authEventID)
}
}
for _, prevEventID := range e.PrevEventIDs() {
state, err := u.StateAtEventIDs(ctx, []string{prevEventID})
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
missingPrev = append(missingPrev, prevEventID)
}
}
return
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {