Merge branch 'master' into master

Commit 13bd3bd7dd by Kilos Liu, 2021-07-02 20:56:45 +08:00 (committed by GitHub)
12 changed files with 403 additions and 143 deletions


@@ -33,7 +33,7 @@ func FederationAPI(base *setup.BaseDendrite, cfg *config.Dendrite) {
 		base.PublicFederationAPIMux, base.PublicKeyAPIMux,
 		&base.Cfg.FederationAPI, userAPI, federation, keyRing,
 		rsAPI, fsAPI, base.EDUServerClient(), keyAPI,
-		&base.Cfg.MSCs,
+		&base.Cfg.MSCs, nil,
 	)
 	base.SetupAndServeHTTP(


@@ -0,0 +1,11 @@
+package api
+
+import (
+	"context"
+
+	"github.com/matrix-org/gomatrixserverlib"
+)
+
+type ServersInRoomProvider interface {
+	GetServersForRoom(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName
+}
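For context, this interface lets something other than the roomserver decide which servers to contact for a room. A minimal sketch of a conforming implementation (the staticServersProvider type and its fixed server list are hypothetical, not part of this change):

package demo

import (
	"context"

	"github.com/matrix-org/gomatrixserverlib"
)

// staticServersProvider satisfies ServersInRoomProvider by always
// returning the same fixed set of servers, regardless of room or event.
type staticServersProvider struct {
	servers []gomatrixserverlib.ServerName
}

func (p *staticServersProvider) GetServersForRoom(
	ctx context.Context, roomID string, event *gomatrixserverlib.Event,
) []gomatrixserverlib.ServerName {
	return p.servers
}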


@@ -17,6 +17,7 @@ package federationapi
 import (
 	"github.com/gorilla/mux"
 	eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
+	federationAPI "github.com/matrix-org/dendrite/federationapi/api"
 	federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
 	keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
 	roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@@ -39,10 +40,12 @@ func AddPublicRoutes(
 	eduAPI eduserverAPI.EDUServerInputAPI,
 	keyAPI keyserverAPI.KeyInternalAPI,
 	mscCfg *config.MSCs,
+	servers federationAPI.ServersInRoomProvider,
 ) {
 	routing.Setup(
 		fedRouter, keyRouter, cfg, rsAPI,
 		eduAPI, federationSenderAPI, keyRing,
 		federation, userAPI, keyAPI, mscCfg,
+		servers,
 	)
 }


@@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
 	fsAPI := base.FederationSenderHTTPClient()
 	// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
 	// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
-	federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs)
+	federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs, nil)
 	baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true)
 	defer cancel()
 	serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))


@@ -20,6 +20,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/matrix-org/dendrite/clientapi/jsonerror"
 	eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
+	federationAPI "github.com/matrix-org/dendrite/federationapi/api"
 	federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
 	"github.com/matrix-org/dendrite/internal"
 	"github.com/matrix-org/dendrite/internal/httputil"
@@ -50,6 +51,7 @@ func Setup(
 	userAPI userapi.UserInternalAPI,
 	keyAPI keyserverAPI.KeyInternalAPI,
 	mscCfg *config.MSCs,
+	servers federationAPI.ServersInRoomProvider,
 ) {
 	v2keysmux := keyMux.PathPrefix("/v2").Subrouter()
 	v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
@@ -99,7 +101,7 @@ func Setup(
 		func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
 			return Send(
 				httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
-				cfg, rsAPI, eduAPI, keyAPI, keys, federation, mu,
+				cfg, rsAPI, eduAPI, keyAPI, keys, federation, mu, servers,
 			)
 		},
 	)).Methods(http.MethodPut, http.MethodOptions)


@@ -16,16 +16,16 @@ package routing
 import (
 	"context"
-	"database/sql"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"net/http"
 	"sync"
 	"time"
 
-	"github.com/getsentry/sentry-go"
 	"github.com/matrix-org/dendrite/clientapi/jsonerror"
 	eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
+	federationAPI "github.com/matrix-org/dendrite/federationapi/api"
 	"github.com/matrix-org/dendrite/internal"
 	keyapi "github.com/matrix-org/dendrite/keyserver/api"
 	"github.com/matrix-org/dendrite/roomserver/api"
@@ -34,6 +34,7 @@ import (
 	"github.com/matrix-org/util"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sirupsen/logrus"
+	"go.uber.org/atomic"
 )
 
 const (
@@ -88,6 +89,67 @@ func init() {
 	)
 }
 
+type sendFIFOQueue struct {
+	tasks  []*inputTask
+	count  int
+	mutex  sync.Mutex
+	notifs chan struct{}
+}
+
+func newSendFIFOQueue() *sendFIFOQueue {
+	q := &sendFIFOQueue{
+		notifs: make(chan struct{}, 1),
+	}
+	return q
+}
+
+func (q *sendFIFOQueue) push(frame *inputTask) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	q.tasks = append(q.tasks, frame)
+	q.count++
+	select {
+	case q.notifs <- struct{}{}:
+	default:
+	}
+}
+
+// pop returns the first item of the queue, if there is one.
+// The second return value will indicate if a task was returned.
+func (q *sendFIFOQueue) pop() (*inputTask, bool) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	if q.count == 0 {
+		return nil, false
+	}
+	frame := q.tasks[0]
+	q.tasks[0] = nil
+	q.tasks = q.tasks[1:]
+	q.count--
+	if q.count == 0 {
+		// Force a GC of the underlying array, since it might have
+		// grown significantly if the queue was hammered for some reason
+		q.tasks = nil
+	}
+	return frame, true
+}
+
+type inputTask struct {
+	ctx      context.Context
+	t        *txnReq
+	event    *gomatrixserverlib.Event
+	wg       *sync.WaitGroup
+	err      error         // written back by worker, only safe to read when all tasks are done
+	duration time.Duration // written back by worker, only safe to read when all tasks are done
+}
+
+type inputWorker struct {
+	running atomic.Bool
+	input   *sendFIFOQueue
+}
+
+var inputWorkers sync.Map // room ID -> *inputWorker
 
 // Send implements /_matrix/federation/v1/send/{txnID}
 func Send(
 	httpReq *http.Request,
@@ -100,14 +162,16 @@ func Send(
 	keys gomatrixserverlib.JSONVerifier,
 	federation *gomatrixserverlib.FederationClient,
 	mu *internal.MutexByRoom,
+	servers federationAPI.ServersInRoomProvider,
 ) util.JSONResponse {
 	t := txnReq{
 		rsAPI:      rsAPI,
 		eduAPI:     eduAPI,
 		keys:       keys,
 		federation: federation,
+		hadEvents:  make(map[string]bool),
 		haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
-		newEvents:  make(map[string]bool),
+		servers:    servers,
 		keyAPI:     keyAPI,
 		roomsMu:    mu,
 	}
@@ -159,21 +223,21 @@ func Send(
 type txnReq struct {
 	gomatrixserverlib.Transaction
 	rsAPI      api.RoomserverInternalAPI
 	eduAPI     eduserverAPI.EDUServerInputAPI
 	keyAPI     keyapi.KeyInternalAPI
 	keys       gomatrixserverlib.JSONVerifier
 	federation txnFederationClient
-	servers      []gomatrixserverlib.ServerName
-	serversMutex sync.RWMutex
 	roomsMu *internal.MutexByRoom
+	// something that can tell us about which servers are in a room right now
+	servers federationAPI.ServersInRoomProvider
+	// a list of events from the auth and prev events which we already had
+	hadEvents map[string]bool
 	// local cache of events for auth checks, etc - this may include events
 	// which the roomserver is unaware of.
 	haveEvents map[string]*gomatrixserverlib.HeaderedEvent
-	// new events which the roomserver does not know about
-	newEvents      map[string]bool
-	newEventsMutex sync.RWMutex
+	haveEventsMutex sync.Mutex
 	work string // metrics
 }
 // A subset of FederationClient functionality that txn requires. Useful for testing.
@@ -189,8 +253,12 @@ type txnFederationClient interface {
 func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
 	results := make(map[string]gomatrixserverlib.PDUResult)
+	//var resultsMutex sync.Mutex
+	var wg sync.WaitGroup
+	var tasks []*inputTask
+	wg.Add(1) // for processEDUs
 
-	pdus := []*gomatrixserverlib.HeaderedEvent{}
 	for _, pdu := range t.PDUs {
 		pduCountTotal.WithLabelValues("total").Inc()
 		var header struct {
@@ -241,83 +309,97 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
 			}
 			continue
 		}
-		pdus = append(pdus, event.Headered(verRes.RoomVersion))
+		v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
+			input: newSendFIFOQueue(),
+		})
+		worker := v.(*inputWorker)
+		if !worker.running.Load() {
+			go worker.run()
+		}
+		wg.Add(1)
+		task := &inputTask{
+			ctx:   ctx,
+			t:     t,
+			event: event,
+			wg:    &wg,
+		}
+		tasks = append(tasks, task)
+		worker.input.push(task)
 	}
-	// Process the events.
-	for _, e := range pdus {
-		evStart := time.Now()
-		if err := t.processEvent(ctx, e.Unwrap()); err != nil {
-			// If the error is due to the event itself being bad then we skip
-			// it and move onto the next event. We report an error so that the
-			// sender knows that we have skipped processing it.
-			//
-			// However if the event is due to a temporary failure in our server
-			// such as a database being unavailable then we should bail, and
-			// hope that the sender will retry when we are feeling better.
-			//
-			// It is uncertain what we should do if an event fails because
-			// we failed to fetch more information from the sending server.
-			// For example if a request to /state fails.
-			// If we skip the event then we risk missing the event until we
-			// receive another event referencing it.
-			// If we bail and stop processing then we risk wedging incoming
-			// transactions from that server forever.
-			if isProcessingErrorFatal(err) {
-				sentry.CaptureException(err)
-				// Any other error should be the result of a temporary error in
-				// our server so we should bail processing the transaction entirely.
-				util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
-				jsonErr := util.ErrorResponse(err)
-				processEventSummary.WithLabelValues(t.work, MetricsOutcomeFatal).Observe(
-					float64(time.Since(evStart).Nanoseconds()) / 1000.,
-				)
-				return nil, &jsonErr
-			} else {
-				// Auth errors mean the event is 'rejected' which have to be silent to appease sytest
-				errMsg := ""
-				outcome := MetricsOutcomeRejected
-				_, rejected := err.(*gomatrixserverlib.NotAllowed)
-				if !rejected {
-					errMsg = err.Error()
-					outcome = MetricsOutcomeFail
-				}
-				util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
-					"Failed to process incoming federation event, skipping",
-				)
-				processEventSummary.WithLabelValues(t.work, outcome).Observe(
-					float64(time.Since(evStart).Nanoseconds()) / 1000.,
-				)
-				results[e.EventID()] = gomatrixserverlib.PDUResult{
-					Error: errMsg,
-				}
-			}
-		} else {
-			results[e.EventID()] = gomatrixserverlib.PDUResult{}
-			pduCountTotal.WithLabelValues("success").Inc()
-			processEventSummary.WithLabelValues(t.work, MetricsOutcomeOK).Observe(
-				float64(time.Since(evStart).Nanoseconds()) / 1000.,
-			)
-		}
-	}
-	t.processEDUs(ctx)
+	go func() {
+		defer wg.Done()
+		t.processEDUs(ctx)
+	}()
+
+	wg.Wait()
+
+	for _, task := range tasks {
+		if task.err != nil {
+			results[task.event.EventID()] = gomatrixserverlib.PDUResult{
+				Error: task.err.Error(),
+			}
+		} else {
+			results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
+		}
+	}
 
 	if c := len(results); c > 0 {
 		util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
 	}
 	return &gomatrixserverlib.RespSend{PDUs: results}, nil
 }
-// isProcessingErrorFatal returns true if the error is really bad and
-// we should stop processing the transaction, and returns false if it
-// is just some less serious error about a specific event.
-func isProcessingErrorFatal(err error) bool {
-	switch err {
-	case sql.ErrConnDone:
-	case sql.ErrTxDone:
-		return true
-	}
-	return false
-}
+func (t *inputWorker) run() {
+	if !t.running.CAS(false, true) {
+		return
+	}
+	defer t.running.Store(false)
+	for {
+		task, ok := t.input.pop()
+		if !ok {
+			return
+		}
+		if task == nil {
+			continue
+		}
+		func() {
+			defer task.wg.Done()
+			select {
+			case <-task.ctx.Done():
+				task.err = context.DeadlineExceeded
+				return
+			default:
+				evStart := time.Now()
+				task.err = task.t.processEvent(task.ctx, task.event)
+				task.duration = time.Since(evStart)
+				if err := task.err; err != nil {
+					switch err.(type) {
+					case *gomatrixserverlib.NotAllowed:
+						processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
+							float64(time.Since(evStart).Nanoseconds()) / 1000.,
+						)
+						util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
+							"Failed to process incoming federation event, skipping",
+						)
+						task.err = nil // make "rejected" failures silent
+					default:
+						processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
+							float64(time.Since(evStart).Nanoseconds()) / 1000.,
+						)
+						util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
+							"Failed to process incoming federation event, skipping",
+						)
					}
+				} else {
+					pduCountTotal.WithLabelValues("success").Inc()
+					processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
+						float64(time.Since(evStart).Nanoseconds()) / 1000.,
+					)
+				}
+			}
+		}()
+	}
+}
 type roomNotFoundError struct {
@@ -340,19 +422,6 @@ func (e missingPrevEventsError) Error() string {
 	return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err)
 }
 
-func (t *txnReq) haveEventIDs() map[string]bool {
-	t.newEventsMutex.RLock()
-	defer t.newEventsMutex.RUnlock()
-	result := make(map[string]bool, len(t.haveEvents))
-	for eventID := range t.haveEvents {
-		if t.newEvents[eventID] {
-			continue
-		}
-		result[eventID] = true
-	}
-	return result
-}
-
 func (t *txnReq) processEDUs(ctx context.Context) {
 	for _, e := range t.EDUs {
 		eduCountTotal.Inc()
@@ -479,22 +548,24 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
 	}
 }
 
-func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
-	t.serversMutex.Lock()
-	defer t.serversMutex.Unlock()
-	if t.servers != nil {
-		return t.servers
-	}
-	t.servers = []gomatrixserverlib.ServerName{t.Origin}
-	serverReq := &api.QueryServerJoinedToRoomRequest{
-		RoomID: roomID,
-	}
-	serverRes := &api.QueryServerJoinedToRoomResponse{}
-	if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
-		t.servers = append(t.servers, serverRes.ServerNames...)
-		util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(t.servers), roomID)
-	}
-	return t.servers
+func (t *txnReq) getServers(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName {
+	// The server that sent us the event should be sufficient to tell us about missing
+	// prev and auth events.
+	servers := []gomatrixserverlib.ServerName{t.Origin}
+	// If the event origin is different to the transaction origin then we can use
+	// this as a last resort. The origin server that created the event would have
+	// had to know the auth and prev events.
+	if event != nil {
+		if origin := event.Origin(); origin != t.Origin {
+			servers = append(servers, origin)
+		}
+	}
+	// If a specific room-to-server provider exists then use that. This will primarily
+	// be used for the P2P demos.
+	if t.servers != nil {
+		servers = append(servers, t.servers.GetServersForRoom(ctx, roomID, event)...)
+	}
+	return servers
 }
 func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error {
@@ -527,6 +598,15 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e
 		return roomNotFoundError{e.RoomID()}
 	}
 
+	// Prepare a map of all the events we already had before this point, so
+	// that we don't send them to the roomserver again.
+	for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) {
+		t.hadEvents[eventID] = true
+	}
+	for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) {
+		t.hadEvents[eventID] = false
+	}
+
 	if len(stateResp.MissingAuthEventIDs) > 0 {
 		t.work = MetricsWorkMissingAuthEvents
 		logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
@@ -570,11 +650,14 @@ func (t *txnReq) retrieveMissingAuthEvents(
 withNextEvent:
 	for missingAuthEventID := range missingAuthEvents {
 	withNextServer:
-		for _, server := range t.getServers(ctx, e.RoomID()) {
+		for _, server := range t.getServers(ctx, e.RoomID(), e) {
 			logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
 			tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
 			if err != nil {
 				logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID)
+				if errors.Is(err, context.DeadlineExceeded) {
+					return err
+				}
 				continue withNextServer
 			}
 			ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
@@ -596,6 +679,8 @@ withNextEvent:
 			); err != nil {
 				return fmt.Errorf("api.SendEvents: %w", err)
 			}
+			t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now
+			t.cacheAndReturn(ev.Headered(stateResp.RoomVersion))
 			delete(missingAuthEvents, missingAuthEventID)
 			continue withNextEvent
 		}
@@ -618,14 +703,14 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv
 	return gomatrixserverlib.Allowed(e, &authUsingState)
 }
 
+var processEventWithMissingStateMutexes = internal.NewMutexByRoom()
+
 func (t *txnReq) processEventWithMissingState(
 	ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
 ) error {
-	// Do this with a fresh context, so that we keep working even if the
-	// original request times out. With any luck, by the time the remote
-	// side retries, we'll have fetched the missing state.
-	gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
-	defer cancel()
+	processEventWithMissingStateMutexes.Lock(e.RoomID())
+	defer processEventWithMissingStateMutexes.Unlock(e.RoomID())
+
 	// We are missing the previous events for this events.
 	// This means that there is a gap in our view of the history of the
 	// room. There two ways that we can handle such a gap:
@@ -646,7 +731,7 @@ func (t *txnReq) processEventWithMissingState(
 	// - fill in the gap completely then process event `e` returning no backwards extremity
 	// - fail to fill in the gap and tell us to terminate the transaction err=not nil
 	// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
-	newEvents, err := t.getMissingEvents(gmectx, e, roomVersion)
+	newEvents, err := t.getMissingEvents(ctx, e, roomVersion)
 	if err != nil {
 		return err
 	}
@@ -673,7 +758,7 @@ func (t *txnReq) processEventWithMissingState(
 	// Look up what the state is after the backward extremity. This will either
 	// come from the roomserver, if we know all the required events, or it will
 	// come from a remote server via /state_ids if not.
-	prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
+	prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
 	if lerr != nil {
 		util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
 		return lerr
@@ -717,7 +802,7 @@ func (t *txnReq) processEventWithMissingState(
 	}
 	// There's more than one previous state - run them all through state res
 	t.roomsMu.Lock(e.RoomID())
-	resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity)
+	resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity)
 	t.roomsMu.Unlock(e.RoomID())
 	if err != nil {
 		util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
@@ -734,7 +819,7 @@ func (t *txnReq) processEventWithMissingState(
 		api.KindOld,
 		resolvedState,
 		backwardsExtremity.Headered(roomVersion),
-		t.haveEventIDs(),
+		t.hadEvents,
 	)
 	if err != nil {
 		return fmt.Errorf("api.SendEventWithState: %w", err)
@@ -786,7 +871,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
 	default:
 		return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
 	}
-	t.cacheAndReturn(h)
+	h = t.cacheAndReturn(h)
 	if h.StateKey() != nil {
 		addedToState := false
 		for i := range respState.StateEvents {
@@ -806,6 +891,8 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
 }
 
 func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent {
+	t.haveEventsMutex.Lock()
+	defer t.haveEventsMutex.Unlock()
 	if cached, exists := t.haveEvents[ev.EventID()]; exists {
 		return cached
 	}
@@ -828,6 +915,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 		// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
 		// processEvent request, which is better for memory.
 		stateEvents[i] = t.cacheAndReturn(ev)
+		t.hadEvents[ev.EventID()] = true
 	}
 	// we should never access res.StateEvents again so we delete it here to make GC faster
 	res.StateEvents = nil
@@ -835,6 +923,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 	var authEvents []*gomatrixserverlib.Event
 	missingAuthEvents := map[string]bool{}
 	for _, ev := range stateEvents {
+		t.haveEventsMutex.Lock()
 		for _, ae := range ev.AuthEventIDs() {
 			if aev, ok := t.haveEvents[ae]; ok {
 				authEvents = append(authEvents, aev.Unwrap())
@@ -842,6 +931,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 				missingAuthEvents[ae] = true
 			}
 		}
+		t.haveEventsMutex.Unlock()
 	}
 	// QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't
 	// have stored the event.
@@ -858,8 +948,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 		if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
 			return nil
 		}
-		for i := range queryRes.Events {
+		for i, ev := range queryRes.Events {
 			authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
+			t.hadEvents[ev.EventID()] = true
 		}
 		queryRes.Events = nil
 	}
@@ -934,12 +1025,13 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
 		return nil, err
 	}
 	latestEvents := make([]string, len(res.LatestEvents))
-	for i := range res.LatestEvents {
+	for i, ev := range res.LatestEvents {
 		latestEvents[i] = res.LatestEvents[i].EventID
+		t.hadEvents[ev.EventID] = true
 	}
 
 	var missingResp *gomatrixserverlib.RespMissingEvents
-	servers := t.getServers(ctx, e.RoomID())
+	servers := t.getServers(ctx, e.RoomID(), e)
 	for _, server := range servers {
 		var m gomatrixserverlib.RespMissingEvents
 		if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{
@@ -953,6 +1045,9 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
 			break
 		} else {
 			logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.Origin, server)
+			if errors.Is(err, context.DeadlineExceeded) {
+				break
+			}
 		}
 	}
@@ -980,6 +1075,12 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
 	// For now, we do not allow Case B, so reject the event.
 	logger.Infof("get_missing_events returned %d events", len(missingResp.Events))
 
+	// Make sure events from the missingResp are using the cache - missing events
+	// will be added and duplicates will be removed.
+	for i, ev := range missingResp.Events {
+		missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
+	}
+
 	// topologically sort and sanity check that we are making forward progress
 	newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents)
 	shouldHaveSomeEventIDs := e.PrevEventIDs()
@@ -1018,6 +1119,14 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID
 	if err := state.Check(ctx, t.keys, nil); err != nil {
 		return nil, err
 	}
+	// Cache the results of this state lookup and deduplicate anything we already
+	// have in the cache, freeing up memory.
+	for i, ev := range state.AuthEvents {
+		state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
+	}
+	for i, ev := range state.StateEvents {
+		state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
+	}
 	return &state, nil
 }
@@ -1033,6 +1142,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 	wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...)
 	missing := make(map[string]bool)
 	var missingEventList []string
+	t.haveEventsMutex.Lock()
 	for _, sid := range wantIDs {
 		if _, ok := t.haveEvents[sid]; !ok {
 			if !missing[sid] {
@@ -1041,6 +1151,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 			}
 		}
 	}
+	t.haveEventsMutex.Unlock()
 
 	// fetch as many as we can from the roomserver
 	queryReq := api.QueryEventsByIDRequest{
@@ -1050,9 +1161,10 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 	if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
 		return nil, err
 	}
-	for i := range queryRes.Events {
+	for i, ev := range queryRes.Events {
+		queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i])
+		t.hadEvents[ev.EventID()] = true
 		evID := queryRes.Events[i].EventID()
-		t.cacheAndReturn(queryRes.Events[i])
 		if missing[evID] {
 			delete(missing, evID)
 		}
@@ -1153,6 +1265,9 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) (
 	*gomatrixserverlib.RespState, error) { // nolint:unparam
+	t.haveEventsMutex.Lock()
+	defer t.haveEventsMutex.Unlock()
 
 	// create a RespState response using the response to /state_ids as a guide
 	respState := gomatrixserverlib.RespState{}
@@ -1193,11 +1308,14 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
 	}
 	var event *gomatrixserverlib.Event
 	found := false
-	servers := t.getServers(ctx, roomID)
+	servers := t.getServers(ctx, roomID, nil)
 	for _, serverName := range servers {
 		txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
 		if err != nil || len(txn.PDUs) == 0 {
 			util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")
+			if errors.Is(err, context.DeadlineExceeded) {
+				break
+			}
 			continue
 		}
 		event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion)
@@ -1216,9 +1334,5 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
 		util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
 		return nil, verifySigError{event.EventID(), err}
 	}
-	h := event.Headered(roomVersion)
-	t.newEventsMutex.Lock()
-	t.newEvents[h.EventID()] = true
-	t.newEventsMutex.Unlock()
-	return h, nil
+	return t.cacheAndReturn(event.Headered(roomVersion)), nil
 }
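Stepping back, the new /send path boils down to a per-room fan-out: PDUs are bucketed by room ID, each room gets one worker started on demand, and a WaitGroup joins everything (the PDU workers plus the EDU goroutine) before results are collected. A distilled, self-contained sketch of that shape, where a plain buffered channel and sync.Once stand in for sendFIFOQueue and the CAS-guarded worker:

package main

import (
	"fmt"
	"sync"
)

type roomWorker struct {
	queue chan string // stands in for sendFIFOQueue
	once  sync.Once   // stands in for the running CAS
}

func (w *roomWorker) start(wg *sync.WaitGroup) {
	w.once.Do(func() {
		go func() {
			for eventID := range w.queue {
				// Events for the same room are processed strictly in order.
				fmt.Println("processed", eventID)
				wg.Done()
			}
		}()
	})
}

func main() {
	var workers sync.Map // room ID -> *roomWorker, like inputWorkers above
	var wg sync.WaitGroup
	pdus := []struct{ roomID, eventID string }{
		{"!a:x", "$1"}, {"!a:x", "$2"}, {"!b:y", "$3"},
	}
	for _, pdu := range pdus {
		v, _ := workers.LoadOrStore(pdu.roomID, &roomWorker{queue: make(chan string, 16)})
		w := v.(*roomWorker)
		w.start(&wg)
		wg.Add(1)
		w.queue <- pdu.eventID
	}
	wg.Wait() // all PDUs done; EDUs would run under the same WaitGroup
}

The design choice this buys: rooms no longer block each other inside one transaction, but ordering within a room is preserved because each room has exactly one consumer.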


@@ -370,7 +370,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat
 		keys:       &test.NopJSONVerifier{},
 		federation: fedClient,
 		haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
-		newEvents:  make(map[string]bool),
+		hadEvents:  make(map[string]bool),
 		roomsMu:    internal.NewMutexByRoom(),
 	}
 	t.PDUs = pdus


@@ -28,6 +28,7 @@ import (
 	"github.com/matrix-org/dendrite/roomserver/api"
 	"github.com/matrix-org/dendrite/roomserver/storage"
 	"github.com/matrix-org/gomatrixserverlib"
+	"github.com/prometheus/client_golang/prometheus"
 	log "github.com/sirupsen/logrus"
 	"go.uber.org/atomic"
 )
@@ -38,8 +39,7 @@ type Inputer struct {
 	ServerName           gomatrixserverlib.ServerName
 	ACLs                 *acls.ServerACLs
 	OutputRoomEventTopic string
-
 	workers sync.Map // room ID -> *inputWorker
 }
 
 type inputTask struct {
@@ -52,7 +52,7 @@ type inputTask struct {
 type inputWorker struct {
 	r       *Inputer
 	running atomic.Bool
-	input   chan *inputTask
+	input   *fifoQueue
 }
 
 // Guarded by a CAS on w.running
@@ -60,7 +60,14 @@ func (w *inputWorker) start() {
 	defer w.running.Store(false)
 	for {
 		select {
-		case task := <-w.input:
+		case <-w.input.wait():
+			task, ok := w.input.pop()
+			if !ok {
+				continue
+			}
+			roomserverInputBackpressure.With(prometheus.Labels{
+				"room_id": task.event.Event.RoomID(),
+			}).Dec()
 			hooks.Run(hooks.KindNewEventReceived, task.event.Event)
 			_, task.err = w.r.processRoomEvent(task.ctx, task.event)
 			if task.err == nil {
@@ -117,6 +124,20 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er
 	return errs
 }
 
+func init() {
+	prometheus.MustRegister(roomserverInputBackpressure)
+}
+
+var roomserverInputBackpressure = prometheus.NewGaugeVec(
+	prometheus.GaugeOpts{
+		Namespace: "dendrite",
+		Subsystem: "roomserver",
+		Name:      "input_backpressure",
+		Help:      "How many events are queued for input for a given room",
+	},
+	[]string{"room_id"},
+)
+
 // InputRoomEvents implements api.RoomserverInternalAPI
 func (r *Inputer) InputRoomEvents(
 	_ context.Context,
@@ -143,7 +164,7 @@ func (r *Inputer) InputRoomEvents(
 		// room - the channel will be quite small as it's just pointer types.
 		w, _ := r.workers.LoadOrStore(roomID, &inputWorker{
 			r:     r,
-			input: make(chan *inputTask, 32),
+			input: newFIFOQueue(),
 		})
 		worker := w.(*inputWorker)
@@ -160,7 +181,10 @@ func (r *Inputer) InputRoomEvents(
 		if worker.running.CAS(false, true) {
 			go worker.start()
 		}
-		worker.input <- tasks[i]
+		worker.input.push(tasks[i])
+		roomserverInputBackpressure.With(prometheus.Labels{
+			"room_id": roomID,
+		}).Inc()
 	}
 
 	// Wait for all of the workers to return results about our tasks.

@@ -0,0 +1,64 @@
+package input
+
+import (
+	"sync"
+)
+
+type fifoQueue struct {
+	tasks  []*inputTask
+	count  int
+	mutex  sync.Mutex
+	notifs chan struct{}
+}
+
+func newFIFOQueue() *fifoQueue {
+	q := &fifoQueue{
+		notifs: make(chan struct{}, 1),
+	}
+	return q
+}
+
+func (q *fifoQueue) push(frame *inputTask) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	q.tasks = append(q.tasks, frame)
+	q.count++
+	select {
+	case q.notifs <- struct{}{}:
+	default:
+	}
+}
+
+// pop returns the first item of the queue, if there is one.
+// The second return value will indicate if a task was returned.
+// You must check this value, even after calling wait().
+func (q *fifoQueue) pop() (*inputTask, bool) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	if q.count == 0 {
+		return nil, false
+	}
+	frame := q.tasks[0]
+	q.tasks[0] = nil
+	q.tasks = q.tasks[1:]
+	q.count--
+	if q.count == 0 {
+		// Force a GC of the underlying array, since it might have
+		// grown significantly if the queue was hammered for some reason
+		q.tasks = nil
+	}
+	return frame, true
+}
+
+// wait returns a channel which can be used to detect when an
+// item is waiting in the queue.
+func (q *fifoQueue) wait() <-chan struct{} {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	if q.count > 0 && len(q.notifs) == 0 {
+		ch := make(chan struct{})
+		close(ch)
+		return ch
+	}
+	return q.notifs
+}
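A hypothetical consumer loop for this queue, written as if it lived in the same package, shows why pop() must still be checked after wait() fires: the single-slot notifs channel only records that something was pushed at some point, not that an item is still there by the time this consumer runs.

// consume drains q forever, calling process for each task in FIFO order.
func consume(q *fifoQueue, process func(*inputTask)) {
	for {
		<-q.wait() // a closed or signalled channel: an item was queued
		task, ok := q.pop()
		if !ok {
			// Spurious wakeup, or another consumer raced us to the item.
			continue
		}
		process(task)
	}
}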


@@ -119,11 +119,15 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 			_roomserver_state_snapshots
 			JOIN _roomserver_state_block ON _roomserver_state_block.state_block_nid = ANY (_roomserver_state_snapshots.state_block_nids)
 		WHERE
-			_roomserver_state_snapshots.state_snapshot_nid = ANY ( SELECT DISTINCT
+			_roomserver_state_snapshots.state_snapshot_nid = ANY (
+				SELECT
 				_roomserver_state_snapshots.state_snapshot_nid
 			FROM
 				_roomserver_state_snapshots
-			LIMIT $1 OFFSET $2)) AS _roomserver_state_block
+				ORDER BY _roomserver_state_snapshots.state_snapshot_nid ASC
+				LIMIT $1 OFFSET $2
+			)
+		) AS _roomserver_state_block
 		GROUP BY
 			state_snapshot_nid,
 			room_nid,
@@ -202,6 +206,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 		}
 	}
 
+	// By this point we should have no more state_snapshot_nids below maxsnapshotid in either roomserver_rooms or roomserver_events
+	// If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist
+	// in roomserver_state_snapshots
+	var count int64
+	if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
+		return fmt.Errorf("assertion query failed: %s", err)
+	}
+	if count > 0 {
+		return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
+	}
+	if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
+		return fmt.Errorf("assertion query failed: %s", err)
+	}
+	if count > 0 {
+		return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
+	}
+
 	if _, err = tx.Exec(`
 		DROP TABLE _roomserver_state_snapshots;
 		DROP SEQUENCE roomserver_state_snapshot_nid_seq;
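Both COUNT(*) assertions above have the same shape, and the same pair appears again in the second migration file below. One possible refactor (purely illustrative, not part of this change; it assumes the migration file's existing database/sql, fmt and types imports) makes the invariant a reusable helper:

// assertNoStaleSnapshots fails the migration if any row in the given table
// still references a state_snapshot_nid from before the refactor.
func assertNoStaleSnapshots(tx *sql.Tx, table string, max types.StateSnapshotNID) error {
	var count int64
	query := fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, table)
	if err := tx.QueryRow(query, max).Scan(&count); err != nil {
		return fmt.Errorf("assertion query failed: %w", err)
	}
	if count > 0 {
		return fmt.Errorf("%d rows in %s still reference an unconverted state_snapshot_nid; this is a bug, please report", count, table)
	}
	return nil
}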


@@ -31,6 +31,7 @@ func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
 	m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
 }
 
+// nolint:gocyclo
 func UpStateBlocksRefactor(tx *sql.Tx) error {
 	logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
 	defer logrus.Warn("State storage upgrade complete")
@@ -45,6 +46,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 	}
 	maxsnapshotid++
 	maxblockid++
+	oldMaxSnapshotID := maxsnapshotid
 
 	if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
 		return fmt.Errorf("tx.Exec: %w", err)
@@ -133,6 +135,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 		if jerr != nil {
 			return fmt.Errorf("json.Marshal (new blocks): %w", jerr)
 		}
+		var newsnapshot types.StateSnapshotNID
 		err = tx.QueryRow(`
 			INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids)
@@ -144,7 +147,8 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 			return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
 		}
 		maxsnapshotid++
-		if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
+		_, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid)
+		if err != nil {
 			return fmt.Errorf("tx.Exec (update events): %w", err)
 		}
 		if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
@@ -153,6 +157,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
 		}
 	}
 
+	// By this point we should have no more state_snapshot_nids below oldMaxSnapshotID in either roomserver_rooms or roomserver_events
+	// If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist
+	// in roomserver_state_snapshots
+	var count int64
+	if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
+		return fmt.Errorf("assertion query failed: %s", err)
+	}
+	if count > 0 {
+		return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
+	}
+	if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
+		return fmt.Errorf("assertion query failed: %s", err)
+	}
+	if count > 0 {
+		return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
+	}
+
 	if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil {
 		return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
 	}


@@ -68,7 +68,7 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
 	federationapi.AddPublicRoutes(
 		ssMux, keyMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
 		m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
-		m.EDUInternalAPI, m.KeyAPI, &m.Config.MSCs,
+		m.EDUInternalAPI, m.KeyAPI, &m.Config.MSCs, nil,
 	)
 	mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
 	syncapi.AddPublicRoutes(
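The monolith passes nil here, which (per getServers above) leaves the transaction and event origins as the only candidates for fetching missing events. A caller that wants custom room-to-server discovery, such as the P2P demos, would pass its own provider instead; myProvider below is hypothetical:

	federationapi.AddPublicRoutes(
		ssMux, keyMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
		m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
		m.EDUInternalAPI, m.KeyAPI, &m.Config.MSCs, myProvider,
	)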