From 07bfb791ca616bd3a4aa96691b74c96146d59d90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:48:35 +0200 Subject: [PATCH 1/7] Scope transactions to endpoints (#2799) To avoid returning results from e.g. `/redact` on `/sendToDevice` requests. Takes the raw URL path and uses `filepath.Dir` to remove the `txnID` (file) from it. Co-authored-by: Neil Alexander --- clientapi/routing/redaction.go | 9 ++--- clientapi/routing/sendevent.go | 4 +-- clientapi/routing/sendtodevice.go | 7 ++-- clientapi/routing/server_notices.go | 7 ++-- internal/transactions/transactions.go | 16 +++++---- internal/transactions/transactions_test.go | 42 ++++++++++++++++++---- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 27f0ba5d0..a0f3b1152 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -19,6 +19,9 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,8 +29,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type redactionContent struct { @@ -51,7 +52,7 @@ func SendRedaction( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -144,7 +145,7 @@ func SendRedaction( // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 85f1053f3..114e9088d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -86,7 +86,7 @@ func SendEvent( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -206,7 +206,7 @@ func SendEvent( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to generate the event vs submit diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 4a5f08883..0c0227937 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -16,12 +16,13 @@ import ( "encoding/json" "net/http" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/transactions" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} @@ -33,7 +34,7 @@ func SendToDevice( eventType string, txnID *string, ) util.JSONResponse { if 
txnID != nil { - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -63,7 +64,7 @@ func SendToDevice( } if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 9edeed2f7..7729eddd8 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -21,7 +21,6 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" @@ -29,6 +28,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/version" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -73,7 +74,7 @@ func SendServerNotice( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -251,7 +252,7 @@ func SendServerNotice( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to generate the event vs submit diff --git a/internal/transactions/transactions.go b/internal/transactions/transactions.go index d2eb0f27f..7ff6f5044 100644 --- a/internal/transactions/transactions.go +++ b/internal/transactions/transactions.go @@ -13,6 +13,8 @@ package transactions import ( + "net/url" + "path/filepath" "sync" "time" @@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse type CacheKey struct { AccessToken string TxnID string + Endpoint string } // Cache represents a temporary store for response entries. @@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache { return &t } -// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. +// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Looks in both the txnMaps. // Returns (JSON response, true) if txnID is found, else the returned bool is false. -func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { +func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) { t.RLock() defer t.RUnlock() for _, txns := range t.txnsMaps { - res, ok := txns[CacheKey{accessToken, txnID}] + res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] if ok { return res, true } @@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, return nil, false } -// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. +// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Adds to the front txnMap. 
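+// The endpoint component of the key is filepath.Dir(u.Path): the request
+// path with its final element (the transaction ID) removed. For example,
+// filepath.Dir("/_matrix/client/r0/sendToDevice/m.test/txn1") returns
+// "/_matrix/client/r0/sendToDevice/m.test", so the same txnID sent to a
+// different endpoint produces a different cache key. (The event type
+// "m.test" here is illustrative.)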
-func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { +func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) { t.Lock() defer t.Unlock() - - t.txnsMaps[0][CacheKey{accessToken, txnID}] = res + t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res } // cacheCleanService is responsible for cleaning up entries after cleanupPeriod. diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index aa837f76c..c552550ac 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,9 @@ package transactions import ( "net/http" + "net/url" + "path/filepath" + "reflect" "strconv" "testing" @@ -24,6 +27,16 @@ type fakeType struct { ID string `json:"ID"` } +func TestCompare(t *testing.T) { + u1, _ := url.Parse("/send/1?accessToken=123") + u2, _ := url.Parse("/send/1") + c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)} + c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)} + if !reflect.DeepEqual(c1, c2) { + t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2) + } +} + var ( fakeAccessToken = "aRandomAccessToken" fakeAccessToken2 = "anotherRandomAccessToken" @@ -34,23 +47,28 @@ var ( fakeResponse2 = &util.JSONResponse{ Code: http.StatusOK, JSON: fakeType{ID: "1"}, } + fakeResponse3 = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "2"}, + } ) // TestCache creates a New Cache and tests AddTransaction & FetchTransaction func TestCache(t *testing.T) { fakeTxnCache := New() - fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) + u, _ := url.Parse("") + fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse) // Add entries for noise. for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, fakeTxnID+strconv.Itoa(i), + u, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } - testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) + testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u) if !ok { t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) } else if testResponse.JSON != fakeResponse.JSON { @@ -59,20 +77,30 @@ func TestCache(t *testing.T) { } // TestCacheScope ensures transactions with the same transaction ID are not shared -// across multiple access tokens. +// across multiple access tokens and endpoints. func TestCacheScope(t *testing.T) { cache := New() - cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) - cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) + sendEndpoint, _ := url.Parse("/send/1?accessToken=test") + sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1") + cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3) - if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse.JSON { t.Errorf("Wrong cache entry for (%s, %s). 
Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) } - if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse2.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) } + + // Ensure the txnID is not shared across endpoints + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse3.JSON { + t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) + } } From 9c189b1b80f9c338ac3cfa71cdaca60016de45f7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 18 Oct 2022 09:51:31 +0100 Subject: [PATCH 2/7] Try to make `AddEvent` less expensive (update to matrix-org/gomatrixserverlib@a72a83f) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1303d4004..911d36c1c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 54055f2fd..a141fc9b4 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 3aa92efaa3e814ad0596fc5fc174a2e43124dcf5 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 18 Oct 2022 15:59:08 +0100 Subject: [PATCH 3/7] Namespace user API tables (#2806) This migrates all the various user API tables, indices and sequences to be `userapi_`-namespaced, 
rather than the mess they are all now. --- userapi/consumers/roomserver_test.go | 9 +- .../storage/postgres/account_data_table.go | 8 +- userapi/storage/postgres/accounts_table.go | 14 +-- .../deltas/20200929203058_is_active.go | 4 +- .../deltas/20201001204705_last_seen_ts_ip.go | 12 +- .../2022021013023800_add_account_type.go | 10 +- .../deltas/2022101711000000_rename_tables.go | 102 ++++++++++++++++ userapi/storage/postgres/devices_table.go | 28 ++--- userapi/storage/postgres/key_backup_table.go | 18 +-- .../postgres/key_backup_version_table.go | 20 ++-- userapi/storage/postgres/logintoken_table.go | 12 +- userapi/storage/postgres/openid_table.go | 6 +- userapi/storage/postgres/profile_table.go | 12 +- userapi/storage/postgres/stats_table.go | 16 +-- userapi/storage/postgres/storage.go | 11 ++ userapi/storage/postgres/threepid_table.go | 12 +- userapi/storage/sqlite3/account_data_table.go | 8 +- userapi/storage/sqlite3/accounts_table.go | 14 +-- .../deltas/20200929203058_is_active.go | 20 ++-- .../deltas/20201001204705_last_seen_ts_ip.go | 20 ++-- .../2022021012490600_add_account_type.go | 16 +-- .../deltas/2022101711000000_rename_tables.go | 109 ++++++++++++++++++ userapi/storage/sqlite3/devices_table.go | 24 ++-- userapi/storage/sqlite3/key_backup_table.go | 18 +-- .../sqlite3/key_backup_version_table.go | 16 +-- userapi/storage/sqlite3/logintoken_table.go | 12 +- userapi/storage/sqlite3/openid_table.go | 6 +- userapi/storage/sqlite3/profile_table.go | 12 +- userapi/storage/sqlite3/stats_table.go | 16 +-- userapi/storage/sqlite3/storage.go | 11 ++ userapi/storage/sqlite3/threepid_table.go | 12 +- userapi/storage/storage_test.go | 9 +- userapi/storage/tables/stats_table_test.go | 4 +- userapi/userapi_test.go | 14 ++- 34 files changed, 441 insertions(+), 194 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2022101711000000_rename_tables.go create mode 100644 userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 3bbeb439a..e4587670f 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -10,19 +10,24 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + close() + baseclose() + } } func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 67113367b..0b6a3af6d 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -26,7 +26,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. 
-CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 33fb6dd42..7c309eb4f 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). @@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" type accountsStatements struct { insertAccountStmt *sql.Stmt diff --git a/userapi/storage/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go index 24f87e073..2c5cc2f58 100644 --- a/userapi/storage/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -7,7 +7,7 @@ import ( ) func UpIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE 
account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { } func DownIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index edd3353f0..40e237027 100644 --- a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -8,9 +8,9 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices DROP COLUMN last_seen_ts; - ALTER TABLE device_devices DROP COLUMN ip; - ALTER TABLE device_devices DROP COLUMN user_agent;`) + ALTER TABLE userapi_devices DROP COLUMN last_seen_ts; + ALTER TABLE userapi_devices DROP COLUMN ip; + ALTER TABLE userapi_devices DROP COLUMN user_agent;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go index eb7c3a958..164847e51 100644 --- a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go +++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go @@ -9,10 +9,10 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; -ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE userapi_accounts ALTER COLUMN 
account_type DROP DEFAULT;`, ) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) @@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..1d73d0af4 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,102 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameSequenceMappings = map[string]string{ + "device_session_id_seq": "userapi_device_session_id_seq", + "account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. 
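+//
+// A sketch of the failure mode, purely illustrative:
+//
+//	_, err := tx.ExecContext(ctx, "ALTER TABLE $1 RENAME TO $2;", old, new) // syntax error
+//
+// versus what the migrations below do instead:
+//
+//	q := fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", pq.QuoteIdentifier(old), pq.QuoteIdentifier(new))
+//	_, err := tx.ExecContext(ctx, q) // identifier is quoted into the query text, not bound as a parameter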
+ +func UpRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index f65681aae..8b7fbd6cf 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -31,10 +31,10 @@ import ( const devicesSchema = ` -- This sequence is used for automatic allocation of session_id. -CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; +CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1; -- Stores data about devices. -CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( -- The access token granted to this device. This has to be the primary key -- so we can distinguish which device is making a given request. access_token TEXT NOT NULL PRIMARY KEY, @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices ( -- This can be used as a secure substitution of the access token in situations -- where data is associated with access tokens (e.g. transaction storage), -- so we don't have to store users' access tokens everywhere. - session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), + session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'), -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- access_tokens will be clobbered based on the device ID for a user. device_id TEXT NOT NULL, @@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices ( ); -- Device IDs must be unique for a given user. 
-CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { insertDeviceStmt *sql.Stmt diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index ac0e80617..7b58f7bae 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON 
account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go index e78e4cd51..67c5e5481 100644 --- a/userapi/storage/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -26,40 +26,40 @@ import ( ) const keyBackupVersionTableSchema = ` -CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; +CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq; -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. 
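 -- (allocating from a single global sequence also avoids the read-then-insert
 -- race that a SELECT MAX(version)+1 approach would have under concurrent requests)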
- version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'), algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, etag TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go index 4de96f839..44c6ca4ae 100644 --- a/userapi/storage/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -26,7 +26,7 @@ import ( ) const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. 
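 -- (DeleteLoginToken issues "DELETE ... WHERE token = $1 OR token_expires_at <= $2",
 -- so every deletion also reaps already-expired tokens; this index keeps that cheap)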
-CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" type loginTokenStatements struct { insertStmt *sql.Stmt @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 29c3ddcb4..06ae30d08 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about openid tokens issued for accounts. -CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 6d336eb8e..f686127be 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index c0b317503..20eb0bf46 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -62,7 +62,7 @@ R30Users counts the number of 30 day retained users, defined as: const countR30UsersSQL = ` SELECT platform, COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -121,7 +121,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1) ` // $1 = All non guest AccountType IDs @@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $3 ) AS t GROUP BY user_type ` @@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = 
u.localpart - INNER JOIN account_accounts a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 7d3b9b6a5..c059e3e60 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" // Import the postgres database driver. @@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewPostgresAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 63c08d61f..11af76161 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -26,7 +26,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { selectLocalpartForThreePIDStmt *sql.Stmt diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index cfd8568a9..af12decb3 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -25,7 +25,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. 
-CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 484e90056..671c1aa04 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). @@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" type accountsStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index e25efc695..9158cb365 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -8,8 +8,8 @@ import ( func UpIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE 
account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -17,13 +17,13 @@ CREATE TABLE account_accounts ( is_deactivated BOOLEAN DEFAULT 0 ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`) func DownIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 7f7e95d2d..a9224db6b 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -8,8 +8,8 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices RENAME TO device_devices_tmp; - CREATE TABLE device_devices ( + ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; + CREATE TABLE userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { UNIQUE (localpart, device_id) ); INSERT - INTO device_devices ( + INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' - FROM device_devices_tmp; - DROP TABLE device_devices_tmp;`) + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices RENAME TO device_devices_tmp; -CREATE TABLE IF NOT EXISTS device_devices ( +ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices ( UNIQUE (localpart, device_id) ); INSERT -INTO device_devices ( +INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name -FROM device_devices_tmp; -DROP TABLE 
device_devices_tmp;`) +FROM userapi_devices_tmp; +DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 46532698c..230bc1433 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -9,8 +9,8 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -19,15 +19,15 @@ CREATE TABLE account_accounts ( account_type INTEGER NOT NULL ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id, account_type ) SELECT localpart, created_ts, password_hash, appservice_id, 1 - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; -DROP TABLE account_accounts_tmp;`) +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to add column: %w", err) } @@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`) } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`) + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..4ca1dc475 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,109 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +func UpRenameTables(ctx context.Context, tx *sql.Tx) 
error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. + var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", old, new, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, old, new, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", old, new, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. + var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", new, old, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, new, old, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", new, old, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 27a7524d6..e53a08062 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -35,7 +35,7 @@ const devicesSchema = ` -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; -- Stores data about devices. 
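The rename delta above has to work around a SQLite limitation: ALTER TABLE, unlike DROP TABLE, has no IF EXISTS clause, so each rename is guarded by a lookup in sqlite_schema and skipped on sql.ErrNoRows when the source table was never created. A minimal, self-contained sketch of that guard (the helper name is hypothetical, not part of the patch):

package deltas

import (
	"context"
	"database/sql"
	"fmt"
)

// renameTableIfExists mirrors the guard in UpRenameTables: probe
// sqlite_schema for the source table first, then rename it. Doing the
// probe inside the same transaction keeps the check and the rename atomic.
func renameTableIfExists(ctx context.Context, tx *sql.Tx, from, to string) error {
	var name string
	err := tx.QueryRowContext(
		ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", from,
	).Scan(&name)
	if err == sql.ErrNoRows {
		return nil // the table was never created, so there is nothing to rename
	}
	if err != nil {
		return err
	}
	_, err = tx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", from, to))
	return err
}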
-CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices ( ` const insertDeviceSQL = "" + - "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" const selectDevicesCountSQL = "" + - "SELECT COUNT(access_token) FROM device_devices" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 81726edf9..7883ffb19 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, 
room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go index e85e6f08b..37bc13ed1 100644 --- a/userapi/storage/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -27,7 +27,7 @@ import ( const keyBackupVersionTableSchema = ` -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. 
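Note that insertKeyBackupSQL in the next hunk ends in a RETURNING clause. With database/sql, a statement that hands back a row must be run with QueryRowContext rather than ExecContext, which would discard the generated version number. A sketch of the call shape (the helper and its argument types are illustrative, not the patch's actual wrapper):

package sqlite3

import (
	"context"
	"database/sql"
)

// insertKeyBackupVersion executes an INSERT ... RETURNING version statement
// and scans the newly allocated version number out of the returned row.
func insertKeyBackupVersion(ctx context.Context, stmt *sql.Stmt, userID, algorithm, authData, etag string) (version int64, err error) {
	err = stmt.QueryRowContext(ctx, userID, algorithm, authData, etag).Scan(&version)
	return
}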
@@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( deleted INTEGER DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go index 78d42029a..2abdcb95e 100644 --- a/userapi/storage/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -32,7 +32,7 @@ type loginTokenStatements struct { } const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. -CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { s := &loginTokenStatements{} @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. 
// // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index d6090e0da..875f1a9a5 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 3050ff4b5..267daf044 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 8aa1746c5..35e3c653e 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -63,7 +63,7 @@ R30Users counts the number of 30 day retained users, defined as: const countR30UsersSQL = ` SELECT platform, COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -126,7 +126,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1) ` // $1 = Guest AccountType @@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest' WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $8 ) AS t GROUP BY user_type ` @@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = u.localpart - INNER JOIN account_accounts 
a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 78b7ce588..dd33dc0cf 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // NewDatabase creates a new accounts and profiles database @@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewSQLiteAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index fa174eed5..73af139db 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -27,7 +27,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { db *sql.DB diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 1538a8138..8e5b32b6a 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -16,6 +16,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -29,14 +30,18 @@ var ( ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := 
test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } - return db, close + return db, func() { + close() + baseclose() + } } // Tests storing and getting account data diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 11521c8b0..c4aec552c 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen( timestamp time.Time, ) { t.Helper() - _, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } @@ -119,7 +119,7 @@ func mustUserUpdateRegistered( localpart string, timestamp time.Time, ) { - _, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 984fe8854..4417f4dc0 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -25,6 +25,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/gomatrixserverlib" @@ -48,9 +49,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - - accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { @@ -64,9 +65,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, - }, accountDB, close + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB, func() { + close() + baseclose() + } } func TestQueryProfile(t *testing.T) { From 241d5c47dfa9e5cfadc350f688aab30f9e539fbb Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 19 Oct 2022 10:03:16 +0000 Subject: [PATCH 4/7] Refactor Federation Destination Queues (#2807) This is a refactor of the federation destination queues. 
It fixes a few things, namely: - actually retry outgoing events with backoff behaviour - obtain enough events from the database to fill messages as much as possible - minimize the number of running goroutines - use pure timers for backoff - don't restart queue unless necessary - close the background task when backing off - increase max EDUs in a transaction to match the spec - clean up timers more aggressively to reduce memory usage - add jitter to backoff timers to reduce resource spikes - add a bunch of tests (with real and fake databases) to ensure everything is working --- federationapi/federationapi.go | 9 +- federationapi/queue/destinationqueue.go | 374 ++++--- federationapi/queue/queue.go | 26 +- federationapi/queue/queue_test.go | 1047 +++++++++++++++++++ federationapi/statistics/statistics.go | 131 ++- federationapi/statistics/statistics_test.go | 19 +- federationapi/storage/shared/storage.go | 4 + go.mod | 2 +- 8 files changed, 1410 insertions(+), 202 deletions(-) create mode 100644 federationapi/queue/queue_test.go diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4a13c9d9b..f6dace702 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -116,17 +116,14 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := &statistics.Statistics{ - DB: federationDB, - FailuresUntilBlacklist: cfg.FederationMaxRetries, - } + stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, - cfg.Matrix.ServerName, federation, rsAPI, stats, + cfg.Matrix.ServerName, federation, rsAPI, &stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, @@ -183,5 +180,5 @@ func NewInternalAPI( } time.AfterFunc(time.Minute, cleanExpiredEDUs) - return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 5cb8cae1f..00e02b2d9 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -35,7 +35,7 @@ import ( const ( maxPDUsPerTransaction = 50 - maxEDUsPerTransaction = 50 + maxEDUsPerTransaction = 100 maxPDUsInMemory = 128 maxEDUsInMemory = 128 queueIdleTimeout = time.Second * 30 @@ -64,7 +64,6 @@ type destinationQueue struct { pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. @@ -75,6 +74,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } + // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -102,12 +102,12 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.overflowed.Store(true) } oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep.
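The statistics.go changes that implement the backoff described in the commit message are only summarized in the diffstat above. As a rough illustration of what "backoff with jitter" means in practice (the constants, names and exact formula here are hypothetical, not the code in statistics.go):

package backoff

import (
	"math"
	"math/rand"
	"time"
)

// backoffDuration doubles the wait for every consecutive failure and adds
// up to one second of random jitter, so that many destination queues that
// failed at the same moment do not all retry in the same instant.
func backoffDuration(consecutiveFailures uint32) time.Duration {
	if consecutiveFailures > 16 {
		consecutiveFailures = 16 // cap the exponent to avoid absurd waits
	}
	wait := time.Duration(math.Exp2(float64(consecutiveFailures))) * time.Second
	jitter := time.Duration(rand.Int63n(int64(time.Second)))
	return wait + jitter
}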
- oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } + } else { + oq.overflowed.Store(true) } } @@ -147,12 +147,37 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.overflowed.Store(true) } oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } + } else { + oq.overflowed.Store(true) + } +} + +// handleBackoffNotifier is registered as the backoff notification +// callback with Statistics. It will wake up and notify the queue +// if the queue is currently backing off. +func (oq *destinationQueue) handleBackoffNotifier() { + // Only wake up the queue if it is backing off. + // Otherwise there is no pending work for the queue to handle + // so waking the queue would be a waste of resources. + if oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// wakeQueueAndNotify ensures the destination queue is running and notifies it +// that there is pending work. +func (oq *destinationQueue) wakeQueueAndNotify() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + + // Notify the queue that there are events ready to send. + select { + case oq.notify <- struct{}{}: + default: } } @@ -161,10 +186,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share // then we will interrupt the backoff, causing any federation // requests to retry. func (oq *destinationQueue) wakeQueueIfNeeded() { - // If we are backing off then interrupt the backoff. + // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { - oq.interruptBackoff <- true + destinationQueueBackingOff.Dec() } + // If we aren't running then wake up the queue. if !oq.running.Load() { // Start the queue. @@ -196,38 +222,54 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotEDUs[edu.receipt.String()] = struct{}{} } + overflowed := false if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - let's request no more than that. - if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { + if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil { + if len(pdus) == maxPDUsInMemory { + overflowed = true + } for receipt, pdu := range pdus { if _, ok := gotPDUs[receipt.String()]; ok { continue } oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) retrieved = true + if len(oq.pendingPDUs) == maxPDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) } } + if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { // We have room in memory for some EDUs - let's request no more than that.
- if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { + if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil { + if len(edus) == maxEDUsInMemory { + overflowed = true + } for receipt, edu := range edus { if _, ok := gotEDUs[receipt.String()]; ok { continue } oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) retrieved = true + if len(oq.pendingEDUs) == maxEDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) } } + // If we've retrieved all of the events from the database with room to spare // in memory then we'll no longer consider this queue to be overflowed. - if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { + if !overflowed { oq.overflowed.Store(false) + } else { } // If we've retrieved some events then notify the destination queue goroutine. if retrieved { @@ -238,6 +280,24 @@ func (oq *destinationQueue) getPendingFromDatabase() { } } +// checkNotificationsOnClose checks for any remaining notifications +// and starts a new backgroundSend goroutine if any exist. +func (oq *destinationQueue) checkNotificationsOnClose() { + // NOTE : If we are stopping the queue due to blacklist then it + // doesn't matter if we have been notified of new work since + // this queue instance will be deleted anyway. + if !oq.statistics.Blacklisted() { + select { + case <-oq.notify: + // We received a new notification in between the + // idle timeout firing and stopping the goroutine. + // Immediately restart the queue. + oq.wakeQueueAndNotify() + default: + } + } +} + // backgroundSend is the worker goroutine for sending events. func (oq *destinationQueue) backgroundSend() { // Check if a worker is already running, and if it isn't, then @@ -245,10 +305,17 @@ func (oq *destinationQueue) backgroundSend() { if !oq.running.CompareAndSwap(false, true) { return } + + // Register queue cleanup functions. + // NOTE : The ordering here is very intentional. + defer oq.checkNotificationsOnClose() + defer oq.running.Store(false) + destinationQueueRunning.Inc() defer destinationQueueRunning.Dec() - defer oq.queues.clearQueue(oq) - defer oq.running.Store(false) + + idleTimeout := time.NewTimer(queueIdleTimeout) + defer idleTimeout.Stop() // Mark the queue as overflowed, so we will consult the database // to see if there's anything new to send. @@ -261,59 +328,33 @@ func (oq *destinationQueue) backgroundSend() { oq.getPendingFromDatabase() } + // Reset the queue idle timeout. + if !idleTimeout.Stop() { + select { + case <-idleTimeout.C: + default: + } + } + idleTimeout.Reset(queueIdleTimeout) + // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. select { case <-oq.notify: // There's work to do, either because getPendingFromDatabase - // told us there is, or because a new event has come in via - // sendEvent/sendEDU. - case <-time.After(queueIdleTimeout): + // told us there is, a new event has come in via sendEvent/sendEDU, + // or we are backing off and it is time to retry. + case <-idleTimeout.C: // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to // send. return case <-oq.process.Context().Done(): // The parent process is shutting down, so stop. + oq.statistics.ClearBackoff() return } - // If we are backing off this server then wait for the - // backoff duration to complete first, or until explicitly - // told to retry. 
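The Stop/drain/Reset sequence on idleTimeout above is the standard idiom for re-arming a time.Timer in a loop on the Go versions current when this patch was written (Go 1.23 later changed timer semantics): if the timer has already fired and nothing received from its channel, the stale expiry stays buffered and a bare Reset would let the next receive return immediately. A standalone sketch of the idiom (the helper name is illustrative):

package queue

import "time"

// resetIdleTimer safely re-arms t for another interval. Stop reports false
// once the timer has fired, in which case a buffered, unread expiry must be
// drained before Reset, or the next <-t.C would fire straight away.
func resetIdleTimer(t *time.Timer, d time.Duration) {
	if !t.Stop() {
		select {
		case <-t.C: // discard the stale expiry if it is still buffered
		default: // it was already received elsewhere
		}
	}
	t.Reset(d)
}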
- until, blacklisted := oq.statistics.BackoffInfo() - if blacklisted { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs { - oq.pendingPDUs[i] = nil - } - for i := range oq.pendingEDUs { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = nil - oq.pendingEDUs = nil - oq.pendingMutex.Unlock() - return - } - if until != nil && until.After(time.Now()) { - // We haven't backed off yet, so wait for the suggested amount of - // time. - duration := time.Until(*until) - logrus.Debugf("Backing off %q for %s", oq.destination, duration) - oq.backingOff.Store(true) - destinationQueueBackingOff.Inc() - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - } - destinationQueueBackingOff.Dec() - oq.backingOff.Store(false) - } - // Work out which PDUs/EDUs to include in the next transaction. oq.pendingMutex.RLock() pduCount := len(oq.pendingPDUs) @@ -328,99 +369,52 @@ func (oq *destinationQueue) backgroundSend() { toSendEDUs := oq.pendingEDUs[:eduCount] oq.pendingMutex.RUnlock() + // If we didn't get anything from the database and there are no + // pending EDUs then there's nothing to do - stop here. + if pduCount == 0 && eduCount == 0 { + continue + } + // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. - oq.statistics.Failure() - - } else if transaction { - // If we successfully sent the transaction then clear out - // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs[:pc] { - oq.pendingPDUs[i] = nil + _, blacklisted := oq.statistics.Failure() + if !blacklisted { + // Register the backoff state and exit the goroutine. + // It'll get restarted automatically when the backoff + // completes. + oq.backingOff.Store(true) + destinationQueueBackingOff.Inc() + return + } else { + // Immediately trigger the blacklist logic. + oq.blacklistDestination() + return } - for i := range oq.pendingEDUs[:ec] { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = oq.pendingPDUs[pc:] - oq.pendingEDUs = oq.pendingEDUs[ec:] - oq.pendingMutex.Unlock() + } else { + oq.handleTransactionSuccess(pduCount, eduCount) } } } // nextTransaction creates a new transaction from the pending event -// queue and sends it. Returns true if a transaction was sent or -// false otherwise. +// queue and sends it. +// Returns an error if the transaction wasn't sent. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (bool, int, int, error) { - // If there's no projected transaction ID then generate one. If - // the transaction succeeds then we'll set it back to "" so that - // we generate a new one next time. If it fails, we'll preserve - // it so that we retry with the same transaction ID. 
- oq.transactionIDMutex.Lock() - if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - oq.transactionIDMutex.Unlock() - +) error { // Create the transaction. - t := gomatrixserverlib.Transaction{ - PDUs: []json.RawMessage{}, - EDUs: []gomatrixserverlib.EDU{}, - } - t.Origin = oq.origin - t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = oq.transactionID - - // If we didn't get anything from the database and there are no - // pending EDUs then there's nothing to do - stop here. - if len(pdus) == 0 && len(edus) == 0 { - return false, 0, 0, nil - } - - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt - - // Go through PDUs that we retrieved from the database, if any, - // and add them into the transaction. - for _, pdu := range pdus { - if pdu == nil || pdu.pdu == nil { - continue - } - // Append the JSON of the event, since this is a json.RawMessage type in the - // gomatrixserverlib.Transaction struct - t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) - } - - // Do the same for pending EDUS in the queue. - for _, edu := range edus { - if edu == nil || edu.edu == nil { - continue - } - t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) - } - + t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) // Try to send the transaction to the destination server. - // TODO: we should check for 500-ish fails vs 400-ish here, - // since we shouldn't queue things indefinitely in response - // to a 400-ish error ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() _, err := oq.client.SendTransaction(ctx, t) - switch err.(type) { + switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. if pduReceipts != nil { @@ -439,16 +433,128 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return true, len(t.PDUs), len(t.EDUs), nil + return nil case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. - return false, 0, 0, err + + // TODO: we should check for 500-ish fails vs 400-ish here, + // since we shouldn't queue things indefinitely in response + // to a 400-ish error + code := errResponse.Code + logrus.Debug("Transaction failed with HTTP", code) + return err default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return false, 0, 0, err + return err + } +} + +// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus. +// It also returns the associated event receipts so they can be cleaned from the database in +// the case of a successful transaction. +func (oq *destinationQueue) createTransaction( + pdus []*queuedPDU, + edus []*queuedEDU, +) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { + // If there's no projected transaction ID then generate one. If + // the transaction succeeds then we'll set it back to "" so that + // we generate a new one next time. 
If it fails, we'll preserve + // it so that we retry with the same transaction ID. + oq.transactionIDMutex.Lock() + if oq.transactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + } + oq.transactionIDMutex.Unlock() + + t := gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{}, + EDUs: []gomatrixserverlib.EDU{}, + } + t.Origin = oq.origin + t.Destination = oq.destination + t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = oq.transactionID + + var pduReceipts []*shared.Receipt + var eduReceipts []*shared.Receipt + + // Go through PDUs that we retrieved from the database, if any, + // and add them into the transaction. + for _, pdu := range pdus { + // These should never be nil. + if pdu == nil || pdu.pdu == nil { + continue + } + // Append the JSON of the event, since this is a json.RawMessage type in the + // gomatrixserverlib.Transaction struct + t.PDUs = append(t.PDUs, pdu.pdu.JSON()) + pduReceipts = append(pduReceipts, pdu.receipt) + } + + // Do the same for pending EDUS in the queue. + for _, edu := range edus { + // These should never be nil. + if edu == nil || edu.edu == nil { + continue + } + t.EDUs = append(t.EDUs, *edu.edu) + eduReceipts = append(eduReceipts, edu.receipt) + } + + return t, pduReceipts, eduReceipts +} + +// blacklistDestination removes all pending PDUs and EDUs that have been cached +// and deletes this queue. +func (oq *destinationQueue) blacklistDestination() { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. + logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = nil + oq.pendingEDUs = nil + oq.pendingMutex.Unlock() + + // Delete this queue as no more messages will be sent to this + // destination until it is no longer blacklisted. + oq.statistics.AssignBackoffNotifier(nil) + oq.queues.clearQueue(oq) +} + +// handleTransactionSuccess updates the cached event queues as well as the success and +// backoff information for this server. +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { + // If we successfully sent the transaction then clear out + // the pending events and EDUs, and wipe our transaction ID. 
+ oq.statistics.Success() + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs[:pduCount] { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs[:eduCount] { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = oq.pendingPDUs[pduCount:] + oq.pendingEDUs = oq.pendingEDUs[eduCount:] + oq.pendingMutex.Unlock() + + if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 { + select { + case oq.notify <- struct{}{}: + default: + } } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8245aa5bd..68f789e37 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -162,23 +162,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d if !ok || oq == nil { destinationQueueTotal.Inc() oq = &destinationQueue{ - queues: oqs, - db: oqs.db, - process: oqs.process, - rsAPI: oqs.rsAPI, - origin: oqs.origin, - destination: destination, - client: oqs.client, - statistics: oqs.statistics.ForServer(destination), - notify: make(chan struct{}, 1), - interruptBackoff: make(chan bool), - signing: oqs.signing, + queues: oqs, + db: oqs.db, + process: oqs.process, + rsAPI: oqs.rsAPI, + origin: oqs.origin, + destination: destination, + client: oqs.client, + statistics: oqs.statistics.ForServer(destination), + notify: make(chan struct{}, 1), + signing: oqs.signing, } + oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier) oqs.queues[destination] = oq } return oq } +// clearQueue removes the queue for the provided destination from the +// set of destination queues. func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { oqs.queuesMutex.Lock() defer oqs.queuesMutex.Unlock() @@ -332,7 +334,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } + oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { + queue.statistics.ClearBackoff() queue.wakeQueueIfNeeded() } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go new file mode 100644 index 000000000..6da9e6b30 --- /dev/null +++ b/federationapi/queue/queue_test.go @@ -0,0 +1,1047 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
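A pattern worth calling out before the tests: notify is created above as a buffered channel of capacity one, and every wake-up in the queue code uses a non-blocking send. Any burst of notifications therefore coalesces into at most one pending token, producers never block, and the worker does at most one spare loop iteration. A minimal sketch of the pattern (names are illustrative):

package queue

// notifier coalesces any number of wake-up requests into at most one
// pending notification.
type notifier chan struct{}

func newNotifier() notifier { return make(notifier, 1) }

// poke is the non-blocking send used by wakeQueueAndNotify: if a token is
// already queued, the default branch drops the duplicate on the floor.
func (n notifier) poke() {
	select {
	case n <- struct{}{}:
	default:
	}
}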
+ +package queue + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + "gotest.tools/v3/poll" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { + if realDatabase { + // Real Database/s + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, b.ProcessContext, func() { + dbClose() + baseClose() + } + } else { + // Fake Database + db := createDatabase() + b := struct { + ProcessContext *process.ProcessContext + }{ProcessContext: process.NewProcessContext()} + return db, b.ProcessContext, func() {} + } +} + +func createDatabase() storage.Database { + return &fakeDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + } +} + +type fakeDatabase struct { + storage.Database + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} +} + +var nidMutex sync.Mutex +var nid = int64(0) + +func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingPDUs[&receipt] = &event + return &receipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingEDUs[&receipt] = &edu + return &receipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) 
(pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingPDUs[receipt]; ok { + pdus[receipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingEDUs[receipt]; ok { + edus[receipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[receipt]; ok { + if _, ok := d.associatedPDUs[serverName]; !ok { + d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[receipt]; ok { + if _, ok := d.associatedEDUs[serverName]; !ok { + d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, receipt := range receipts { + delete(pdus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, receipt := range receipts { + delete(edus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range 
d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +type stubFederationRoomServerAPI struct { + rsapi.FederationRoomserverAPI +} + +func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { + res.Banned = false + return nil +} + +type stubFederationClient struct { + api.FederationClient + shouldTxSucceed bool + txCount atomic.Uint32 +} + +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + var result error + if !f.shouldTxSucceed { + result = fmt.Errorf("transaction failed") + } + + f.txCount.Add(1) + return gomatrixserverlib.RespSend{}, result +} + +func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { + t.Helper() + content := `{"type":"m.room.message"}` + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { + t.Helper() + return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} +} + +func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { + db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) + + fc := &stubFederationClient{ + shouldTxSucceed: shouldTxSucceed, + txCount: *atomic.NewUint32(0), + } + rs := &stubFederationRoomServerAPI{} + stats := statistics.NewStatistics(db, failuresUntilBlacklist) + signingInfo := &SigningInfo{ + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + } + queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) + + return db, fc, queues, processContext, close +} + +func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := 
testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreatePDU(t) + err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. 
Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreateEDU(t) + err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsPDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsEDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxPDUsPerTransaction + pduMultiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + } + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxEDUsPerTransaction + eduMultiplier := uint32(3) + for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendPDUAndEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxEDUsPerTransaction + multiplier := uint32(3) + + for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + } + + for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEvent() + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + dest := queues.getQueue(destination) + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + + pollEnd := time.Now().Add(3 * time.Second) + runningCheck := func(log poll.LogT) poll.Result { + if dest.running.Load() || fc.txCount.Load() > 0 { + return poll.Error(fmt.Errorf("The queue was started")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the queue to be started in the case + // of backoff triggering it to start. 
+ return poll.Success() + } + return poll.Continue("waiting to ensure queue doesn't start.") + } + poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { + // NOTE : Only one test case against real databases can be run at a time. + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + edu := mustCreateEDU(t) + errEDU := queues.SendEDU(edu, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, errEDU) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 1 && len(eduData) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + }) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index db6d5c735..2ba99112c 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -2,6 +2,7 @@ package statistics import ( "math" + "math/rand" "sync" "time" @@ -20,12 +21,23 @@ type Statistics struct { servers map[gomatrixserverlib.ServerName]*ServerStatistics mutex sync.RWMutex + backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffMutex sync.RWMutex + // How many times should we tolerate consecutive failures before we // just blacklist the host altogether? 
The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 } +func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { + return Statistics{ + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + } +} + // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { @@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - interrupt chan struct{} // interrupts the backoff goroutine - successCounter atomic.Uint32 // how many times have we succeeded? + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex } +const maxJitterMultiplier = 1.4 +const minJitterMultiplier = 0.8 + // duration returns how long the next backoff interval should be. func (s *ServerStatistics) duration(count uint32) time.Duration { - return time.Second * time.Duration(math.Exp2(float64(count))) + // Add some jitter to minimise the chance of having multiple backoffs + // ending at the same time. + jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier + duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000) + return duration } // cancel will interrupt the currently active backoff. func (s *ServerStatistics) cancel() { s.blacklisted.Store(false) s.backoffUntil.Store(time.Time{}) - select { - case s.interrupt <- struct{}{}: - default: - } + + s.ClearBackoff() +} + +// AssignBackoffNotifier configures the channel to send to when +// a backoff completes. +func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + s.backoffNotifier = notifier } // Success updates the server statistics with a new successful @@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. 
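// Illustrative sketch, not part of the patch: with the two jitter constants
// defined above, duration() picks a wait of roughly [minJitterMultiplier,
// maxJitterMultiplier] * 2^count seconds, so hosts that failed at the same
// moment stop retrying in lockstep. The same arithmetic as a standalone
// helper (hypothetical name, assumes "math", "math/rand" and "time" are imported):
//
//	func jitteredBackoff(count uint32) time.Duration {
//		jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
//		return time.Duration(math.Exp2(float64(count))*jitter*1000) * time.Millisecond
//	}
//
// For count == 3 this yields a wait between ~6.4s and ~11.2s instead of a fixed 8s.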
func (s *ServerStatistics) Success() { s.cancel() - s.successCounter.Inc() s.backoffCount.Store(0) + s.successCounter.Inc() if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -105,13 +130,17 @@ } // Failure marks a failure and starts backing off if needed. -// The next call to BackoffIfRequired will do the right thing -// after this. It will return the time that the current failure +// It will return the time that the current failure // will result in backoff waiting until, and a bool signalling // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { + // Return immediately if we have blacklisted this node. + if s.blacklisted.Load() { + return time.Time{}, true + } + // If we aren't already backing off, this call will start - // a new backoff period. Increase the failure counter and + // a new backoff period, increase the failure counter and // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { @@ -122,40 +151,48 @@ logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) } } + s.ClearBackoff() return time.Time{}, true } - go func() { - until, ok := s.backoffUntil.Load().(time.Time) - if ok && !until.IsZero() { - select { - case <-time.After(time.Until(until)): - case <-s.interrupt: - } - s.backoffStarted.Store(false) - } - }() + // We're starting a new backoff so work out what the next interval + // will be. + count := s.backoffCount.Load() + until := time.Now().Add(s.duration(count)) + s.backoffUntil.Store(until) + + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) } - // Check if we have blacklisted this node. - if s.blacklisted.Load() { - return time.Now(), true - } + return s.backoffUntil.Load().(time.Time), false +} - // If we're already backing off and we haven't yet surpassed - // the deadline then return that. Repeated calls to Failure - // within a single backoff interval will have no side effects. - if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { - return until, false +// ClearBackoff stops the backoff timer for this destination if it is running +// and removes the timer from the backoffTimers map. +func (s *ServerStatistics) ClearBackoff() { + // If the timer is still running then stop it so its memory is cleaned up sooner. + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + if timer, ok := s.statistics.backoffTimers[s.serverName]; ok { + timer.Stop() } + delete(s.statistics.backoffTimers, s.serverName) - // We're either backing off and have passed the deadline, or - // we aren't backing off, so work out what the next interval - // will be. - count := s.backoffCount.Load() - until := time.Now().Add(s.duration(count)) - s.backoffUntil.Store(until) - return until, false + s.backoffStarted.Store(false) +} + +// backoffFinished will clear the previous backoff and notify the destination queue. +func (s *ServerStatistics) backoffFinished() { + s.ClearBackoff() + + // Notify the destinationQueue if one is currently running. 
+ s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + if s.backoffNotifier != nil { + s.backoffNotifier() + } } // BackoffInfo returns information about the current or previous backoff. @@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// RemoveBlacklist removes the blacklisted status from the server. +func (s *ServerStatistics) RemoveBlacklist() { + s.cancel() + s.backoffCount.Store(0) +} + // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 225350b6d..6aa997f44 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -7,9 +7,7 @@ import ( ) func TestBackoff(t *testing.T) { - stats := Statistics{ - FailuresUntilBlacklist: 7, - } + stats := NewStatistics(nil, 7) server := ServerStatistics{ statistics: &stats, serverName: "test.com", @@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) { // Get the duration. _, blacklist := server.BackoffInfo() - duration := time.Until(until).Round(time.Second) + duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that // there's a backoff in progress and return the same result. @@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) { // Check if the duration is what we expect. t.Logf("Backoff %d is for %s", i, duration) - if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { - t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + roundingAllowance := 0.01 + minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance) + maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance) + var inJitterRange bool + if duration >= minDuration && duration <= maxDuration { + inJitterRange = true + } else { + inJitterRange = false + } + if !blacklist && !inJitterRange { + t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration) } } } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 9e40f311c..6afb313a8 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -52,6 +52,10 @@ type Receipt struct { nid int64 } +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + func (r *Receipt) String() string { return fmt.Sprintf("%d", r.nid) } diff --git a/go.mod b/go.mod index 911d36c1c..2248e73c6 100644 --- a/go.mod +++ b/go.mod @@ -50,6 +50,7 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 + gotest.tools/v3 v3.4.0 nhooyr.io/websocket v1.8.7 ) @@ -127,7 +128,6 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/v3 v3.4.0 // indirect ) go 1.18 From f3dae0e749ca35b1527fbfcb0371e89d0e9833ab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 11:40:38 +0100 Subject: [PATCH 5/7] Bump nokogiri from 1.13.6 to 1.13.9 in /docs (#2809) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps 
[nokogiri](https://github.com/sparklemotion/nokogiri) from 1.13.6 to 1.13.9.
Release notes

Sourced from nokogiri's releases.

1.13.9 / 2022-10-18

Security

Dependencies

  • [CRuby] Vendored libxml2 is updated to v2.10.3 from v2.9.14.
  • [CRuby] Vendored libxslt is updated to v1.1.37 from v1.1.35.
  • [CRuby] Vendored zlib is updated from 1.2.12 to 1.2.13. (See LICENSE-DEPENDENCIES.md for details on which packages redistribute this library.)

Fixed

  • [CRuby] Nokogiri::XML::Namespace objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2658] (Thanks, @eightbitraptor and @peterzhu2118!)
  • [CRuby] Document#remove_namespaces! now defers freeing the underlying xmlNs struct until the Document is GCed. Previously, maintaining a reference to a Namespace object that was removed in this way could lead to a segfault. [#2658]

sha256 checksums:

9b69829561d30c4461ea803baeaf3460e8b145cff7a26ce397119577a4083a02  nokogiri-1.13.9-aarch64-linux.gem
e76ebb4b7b2e02c72b2d1541289f8b0679fb5984867cf199d89b8ef485764956  nokogiri-1.13.9-arm64-darwin.gem
15bae7d08bddeaa898d8e3f558723300137c26a2dc2632a1f89c8574c4467165  nokogiri-1.13.9-java.gem
f6a1dbc7229184357f3129503530af73cc59ceba4932c700a458a561edbe04b9  nokogiri-1.13.9-x64-mingw-ucrt.gem
36d935d799baa4dc488024f71881ff0bc8b172cecdfc54781169c40ec02cbdb3  nokogiri-1.13.9-x64-mingw32.gem
ebaf82aa9a11b8fafb67873d19ee48efb565040f04c898cdce8ca0cd53ff1a12  nokogiri-1.13.9-x86-linux.gem
11789a2a11b28bc028ee111f23311461104d8c4468d5b901ab7536b282504154  nokogiri-1.13.9-x86-mingw32.gem
01830e1646803ff91c0fe94bc768ff40082c6de8cfa563dafd01b3f7d5f9d795  nokogiri-1.13.9-x86_64-darwin.gem
8e93b8adec22958013799c8690d81c2cdf8a90b6f6e8150ab22e11895844d781  nokogiri-1.13.9-x86_64-linux.gem
96f37c1baf0234d3ae54c2c89aef7220d4a8a1b03d2675ff7723565b0a095531  nokogiri-1.13.9.gem
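To check a downloaded gem against the list above, hashing the file and comparing the result with the published value is enough. A minimal Go sketch (the file name is a placeholder for whichever gem was downloaded):

	package main

	import (
		"crypto/sha256"
		"fmt"
		"io"
		"os"
	)

	func main() {
		// Placeholder path: point this at the gem that was downloaded.
		const gem = "nokogiri-1.13.9-x86_64-linux.gem"
		f, err := os.Open(gem)
		if err != nil {
			panic(err)
		}
		defer f.Close()

		h := sha256.New()
		if _, err := io.Copy(h, f); err != nil {
			panic(err)
		}
		// Output matches the "checksum  filename" layout above for easy diffing.
		fmt.Printf("%x  %s\n", h.Sum(nil), gem)
	}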

1.13.8 / 2022-07-23

Deprecated

  • XML::Reader#attribute_nodes is deprecated due to incompatibility between libxml2's xmlReader memory semantics and Ruby's garbage collector. Although this method continues to exist for backwards compatibility, it is unsafe to call and may segfault. This method will be removed in a future version of Nokogiri, and callers should use #attribute_hash instead. [#2598]

Improvements

  • XML::Reader#attribute_hash is a new method to safely retrieve the attributes of a node from XML::Reader. [#2598, #2599]

Fixed

... (truncated)

Changelog

Sourced from nokogiri's changelog.

1.13.9 / 2022-10-18

Security

Dependencies

  • [CRuby] Vendored libxml2 is updated to v2.10.3 from v2.9.14.
  • [CRuby] Vendored libxslt is updated to v1.1.37 from v1.1.35.
  • [CRuby] Vendored zlib is updated from 1.2.12 to 1.2.13. (See LICENSE-DEPENDENCIES.md for details on which packages redistribute this library.)

Fixed

  • [CRuby] Nokogiri::XML::Namespace objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2658] (Thanks, @eightbitraptor and @peterzhu2118!)
  • [CRuby] Document#remove_namespaces! now defers freeing the underlying xmlNs struct until the Document is GCed. Previously, maintaining a reference to a Namespace object that was removed in this way could lead to a segfault. [#2658]

1.13.8 / 2022-07-23

Deprecated

  • XML::Reader#attribute_nodes is deprecated due to incompatibility between libxml2's xmlReader memory semantics and Ruby's garbage collector. Although this method continues to exist for backwards compatibility, it is unsafe to call and may segfault. This method will be removed in a future version of Nokogiri, and callers should use #attribute_hash instead. [#2598]

Improvements

  • XML::Reader#attribute_hash is a new method to safely retrieve the attributes of a node from XML::Reader. [#2598, #2599]

Fixed

  • [CRuby] XML::Reader#attributes is now safe to call. In Nokogiri <= 1.13.7 this method may segfault. [#2598, #2599]

1.13.7 / 2022-07-12

Fixed

XML::Node objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2578] (Thanks, @eightbitraptor!)

Commits
  • 897759c version bump to v1.13.9
  • aeb1ac3 doc: update CHANGELOG
  • c663e49 Merge pull request #2671 from sparklemotion/flavorjones-update-zlib-1.2.13_v1...
  • 212e07d ext: hack to cross-compile zlib v1.2.13 on darwin
  • 76dbc8c dep: update zlib to v1.2.13
  • 24e3a9c doc: update CHANGELOG
  • 4db3b4d Merge pull request #2668 from sparklemotion/flavorjones-namespace-scopes-comp...
  • 73d73d6 fix: Document#remove_namespaces! use-after-free bug
  • 5f58b34 fix: namespace nodes behave properly when compacted
  • b08a858 test: repro namespace_scopes compaction issue
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nokogiri&package-manager=bundler&previous-version=1.13.6&new-version=1.13.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
  • `@dependabot rebase` will rebase this PR
  • `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
  • `@dependabot merge` will merge this PR after your CI passes on it
  • `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
  • `@dependabot cancel merge` will cancel a previously requested merge and block automerging
  • `@dependabot reopen` will reopen this PR if it is closed
  • `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language
  • `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language
  • `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language
  • `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language
You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index bc73df728..c7ba43711 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.6-arm64-darwin) + nokogiri (1.13.9-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.6-x86_64-linux) + nokogiri (1.13.9-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) From c1463db6c9183aa67ef41e7ea85ed36dc5817d18 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Oct 2022 12:03:12 +0100 Subject: [PATCH 6/7] Fix concurrent map write in key server --- keyserver/internal/internal.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 06fc4987c..d2ea20935 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap // nolint:gocyclo func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) @@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) for domain, userToDeviceMap := range domainToDeviceKeys { @@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( // we can't safely return keys from the db when all devices are requested as we don't // know if one has just been added. if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) if err == nil { continue } @@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. 
respMu.Lock() - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) respMu.Unlock() if err != nil { logrus.WithFields(logrus.Fields{ @@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // inspecting the failures map though so they can know it's a cached response. for userID, dkeys := range devKeys { // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) } // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache @@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. @@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( if len(deviceIDs) == 0 && len(keys) == 0 { return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) } + respMu.Lock() if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } + respMu.Unlock() for _, key := range keys { if len(key.KeyJSON) == 0 { @@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` }{key.DisplayName}) + respMu.Lock() res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() } return nil } From 8cbe14bd6d985ceb2f7c098548a3fbeedfce2d55 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Oct 2022 12:27:34 +0100 Subject: [PATCH 7/7] Fix lock contention --- keyserver/internal/internal.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index d2ea20935..89621aa87 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -542,9 +542,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. - respMu.Lock() err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) - respMu.Unlock() if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -568,6 +566,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( res.Failures[serverName] = map[string]interface{}{ "message": err.Error(), } + respMu.Unlock() // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server // is down, better to return something than nothing at all. Clients can know about the failure by @@ -578,11 +577,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // Sytest expects no failures, if we still could retrieve keys, e.g. 
from local cache + respMu.Lock() if len(res.DeviceKeys) > 0 { delete(res.Failures, serverName) } respMu.Unlock() - } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
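Both keyserver fixes above come down to the standard Go rule that a map written from multiple goroutines must be guarded by a mutex, and that the mutex should be held only around the map operations themselves so the slower database and federation calls do not serialise on it. A minimal self-contained sketch of that pattern (type and field names here are illustrative, not Dendrite's):

	package main

	import (
		"fmt"
		"sync"
	)

	// queryResponse stands in for a shared response object: a result map
	// plus the mutex that guards it.
	type queryResponse struct {
		mu         sync.Mutex
		deviceKeys map[string]map[string]string // userID -> deviceID -> key JSON
	}

	// addKey locks only around the map mutation, mirroring how the patches
	// scope respMu to the writes into res.DeviceKeys.
	func (r *queryResponse) addKey(userID, deviceID, keyJSON string) {
		r.mu.Lock()
		defer r.mu.Unlock()
		if r.deviceKeys[userID] == nil {
			r.deviceKeys[userID] = make(map[string]string)
		}
		r.deviceKeys[userID][deviceID] = keyJSON
	}

	func main() {
		res := &queryResponse{deviceKeys: make(map[string]map[string]string)}

		var wg sync.WaitGroup
		for i := 0; i < 8; i++ {
			wg.Add(1)
			go func(i int) {
				defer wg.Done()
				// Without the mutex these writes can crash with Go's fatal
				// "concurrent map writes" error, which is what the patch fixes.
				res.addKey("@alice:localhost", fmt.Sprintf("DEVICE%d", i), "{}")
			}(i)
		}
		wg.Wait()
		fmt.Println(len(res.deviceKeys["@alice:localhost"]), "device keys stored")
	}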