diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 5bfe237a8..af027c3c3 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -15,6 +15,8 @@ package federationapi import ( + "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -156,5 +158,16 @@ func NewInternalAPI( if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start presence consumer") } + + var cleanExpiredEDUs func() + cleanExpiredEDUs = func() { + logrus.Infof("Cleaning expired EDUs") + if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil { + logrus.WithError(err).Error("Failed to clean expired EDUs") + } + time.AfterFunc(time.Hour, cleanExpiredEDUs) + } + time.AfterFunc(time.Minute, cleanExpiredEDUs) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index e3038651b..9f2b992f7 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -71,4 +71,6 @@ type Database interface { // Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs)) // such that the combination of all server keys will include all the `optKeyIDs`. GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + // DeleteExpiredEDUs cleans up expired EDUs + DeleteExpiredEDUs(ctx context.Context) error } diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 6cac489bf..0e3122026 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -31,7 +31,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -40,7 +42,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx const insertQueueEDUSQL = "" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)" @@ -61,6 +63,12 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at IS NOT NULL AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at IS NOT NULL AND expires_at <= $1" + type queueEDUsStatements struct { db *sql.DB insertQueueEDUStmt *sql.Stmt @@ -69,6 +77,8 @@ type queueEDUsStatements struct { selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -79,25 +89,18 @@ func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { if err != nil { return } - if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil { - return - } if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil { return } - if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -106,6 +109,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt *gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -113,6 +117,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -146,7 +151,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -196,3 +201,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index b2aea6929..e087af3d5 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -84,6 +84,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC } m := sqlutil.NewMigrations() deltas.LoadRemoveRoomsTable(m) + deltas.LoadAddExpiresAt(m) if err = m.RunDeltas(d.db, dbProperties); err != nil { return nil, err } diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index 02a23338f..28730e8c7 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -20,10 +20,19 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" ) +// expireEDUTypes contains EDUs which can/should be expired after a given time +// if the target server isn't reachable for some reason. +var expireEDUTypes = map[string]struct{}{ + gomatrixserverlib.MTyping: {}, + gomatrixserverlib.MPresence: {}, + gomatrixserverlib.MReceipt: {}, +} + // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. @@ -33,6 +42,12 @@ func (d *Database) AssociateEDUWithDestination( receipt *Receipt, eduType string, ) error { + var expiresAt *gomatrixserverlib.Timestamp + if _, ok := expireEDUTypes[eduType]; ok { + // Keep EDUs for at least one hour before deleting them + ts := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour)) + expiresAt = &ts + } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.FederationQueueEDUs.InsertQueueEDU( ctx, // context @@ -40,6 +55,7 @@ func (d *Database) AssociateEDUWithDestination( eduType, // EDU type for coalescing serverName, // destination server name receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ); err != nil { return fmt.Errorf("InsertQueueEDU: %w", err) } @@ -150,3 +166,23 @@ func (d *Database) GetPendingEDUServerNames( ) ([]gomatrixserverlib.ServerName, error) { return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } + +// DeleteExpiredEDUs deletes expired EDUs +func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) + jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) + if err != nil { + return err + } + if len(jsonNIDs) == 0 { + return nil + } + + if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { + return err + } + + return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) + }) +} diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index a6d609508..64e4683a5 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -32,7 +32,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -41,7 +43,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx const insertQueueEDUSQL = "" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUsSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" @@ -62,13 +64,22 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at IS NOT NULL AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at IS NOT NULL AND expires_at <= $1" + type queueEDUsStatements struct { - db *sql.DB - insertQueueEDUStmt *sql.Stmt + db *sql.DB + insertQueueEDUStmt *sql.Stmt + // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -79,22 +90,15 @@ func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { if err != nil { return } - if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -103,6 +107,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt *gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -110,6 +115,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -155,7 +161,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -205,3 +211,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c2e83211e..939f466c4 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -83,6 +83,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC } m := sqlutil.NewMigrations() deltas.LoadRemoveRoomsTable(m) + deltas.LoadAddExpiresAt(m) if err = m.RunDeltas(d.db, dbProperties); err != nil { return nil, err } diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 19357393d..64c66da90 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -34,12 +34,14 @@ type FederationQueuePDUs interface { } type FederationQueueEDUs interface { - InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt *gomatrixserverlib.Timestamp) error DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) + SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) + DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error } type FederationQueueJSON interface {