Fix graceful shutdown

This commit is contained in:
Neil Alexander 2022-04-27 15:29:49 +01:00
parent 103795d33a
commit 923f789ca3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 33 additions and 21 deletions

View file

@ -78,7 +78,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
if err := oq.db.AssociatePDUWithDestination( if err := oq.db.AssociatePDUWithDestination(
context.TODO(), oq.process.Context(),
"", // TODO: remove this, as we don't need to persist the transaction ID "", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table receipt, // NIDs from federationapi_queue_json table
@ -122,7 +122,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
if err := oq.db.AssociateEDUWithDestination( if err := oq.db.AssociateEDUWithDestination(
context.TODO(), oq.process.Context(),
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table receipt, // NIDs from federationapi_queue_json table
event.Type, event.Type,
@ -177,7 +177,7 @@ func (oq *destinationQueue) getPendingFromDatabase() {
// Check to see if there's anything to do for this server // Check to see if there's anything to do for this server
// in the database. // in the database.
retrieved := false retrieved := false
ctx := context.Background() ctx := oq.process.Context()
oq.pendingMutex.Lock() oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock() defer oq.pendingMutex.Unlock()
@ -271,6 +271,9 @@ func (oq *destinationQueue) backgroundSend() {
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
return return
case <-oq.process.Context().Done():
// The parent process is shutting down, so stop.
return
} }
// If we are backing off this server then wait for the // If we are backing off this server then wait for the
@ -420,13 +423,13 @@ func (oq *destinationQueue) nextTransaction(
// Clean up the transaction in the database. // Clean up the transaction in the database.
if pduReceipts != nil { if pduReceipts != nil {
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String()) //logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil { if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination) logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
} }
} }
if eduReceipts != nil { if eduReceipts != nil {
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String()) //logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil { if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination) logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
} }
} }

View file

@ -15,7 +15,6 @@
package queue package queue
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -105,14 +104,14 @@ func NewOutgoingQueues(
// Look up which servers we have pending items for and then rehydrate those queues. // Look up which servers we have pending items for and then rehydrate those queues.
if !disabled { if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{} serverNames := map[gomatrixserverlib.ServerName]struct{}{}
if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil { if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil {
for _, serverName := range names { for _, serverName := range names {
serverNames[serverName] = struct{}{} serverNames[serverName] = struct{}{}
} }
} else { } else {
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration") log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
} }
if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil { if names, err := db.GetPendingEDUServerNames(process.Context()); err == nil {
for _, serverName := range names { for _, serverName := range names {
serverNames[serverName] = struct{}{} serverNames[serverName] = struct{}{}
} }
@ -215,7 +214,7 @@ func (oqs *OutgoingQueues) SendEvent(
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
if api.IsServerBannedFromRoom( if api.IsServerBannedFromRoom(
context.TODO(), oqs.process.Context(),
oqs.rsAPI, oqs.rsAPI,
ev.RoomID(), ev.RoomID(),
destination, destination,
@ -238,7 +237,7 @@ func (oqs *OutgoingQueues) SendEvent(
return fmt.Errorf("json.Marshal: %w", err) return fmt.Errorf("json.Marshal: %w", err)
} }
nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON)) nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(headeredJSON))
if err != nil { if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }
@ -286,7 +285,7 @@ func (oqs *OutgoingQueues) SendEDU(
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() { if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
for destination := range destmap { for destination := range destmap {
if api.IsServerBannedFromRoom( if api.IsServerBannedFromRoom(
context.TODO(), oqs.process.Context(),
oqs.rsAPI, oqs.rsAPI,
result.Str, result.Str,
destination, destination,
@ -310,7 +309,7 @@ func (oqs *OutgoingQueues) SendEDU(
return fmt.Errorf("json.Marshal: %w", err) return fmt.Errorf("json.Marshal: %w", err)
} }
nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON)) nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(ephemeralJSON))
if err != nil { if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
} }

View file

@ -469,14 +469,14 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite) minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite)
b.WaitForShutdown()
ctx, cancel := context.WithCancel(context.Background()) <-b.ProcessContext.WaitForShutdown()
defer cancel() logrus.Infof("Stopping HTTP listeners")
_ = internalServ.Shutdown(context.Background())
_ = internalServ.Shutdown(ctx) _ = externalServ.Shutdown(context.Background())
_ = externalServ.Shutdown(ctx)
logrus.Infof("Stopped HTTP listeners") logrus.Infof("Stopped HTTP listeners")
b.WaitForShutdown()
} }
func (b *BaseDendrite) WaitForShutdown() { func (b *BaseDendrite) WaitForShutdown() {

View file

@ -35,6 +35,16 @@ func JetStreamConsumer(
} }
go func() { go func() {
for { for {
// If the parent context has given up then there's no point in
// carrying on doing anything, so stop the listener.
select {
case <-ctx.Done():
if err := sub.Unsubscribe(); err != nil {
logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable)
}
return
default:
}
// The context behaviour here is surprising — we supply a context // The context behaviour here is surprising — we supply a context
// so that we can interrupt the fetch if we want, but NATS will still // so that we can interrupt the fetch if we want, but NATS will still
// enforce its own deadline (roughly 5 seconds by default). Therefore // enforce its own deadline (roughly 5 seconds by default). Therefore
@ -65,18 +75,18 @@ func JetStreamConsumer(
continue continue
} }
msg := msgs[0] msg := msgs[0]
if err = msg.InProgress(); err != nil { if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
continue continue
} }
if f(ctx, msg) { if f(ctx, msg) {
if err = msg.AckSync(); err != nil { if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
} }
} else { } else {
if err = msg.Nak(); err != nil { if err = msg.Nak(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
sentry.CaptureException(err) sentry.CaptureException(err)
} }