From 83c9dde219440e0d730d63d2d65fc9eaaea64762 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 17 Oct 2022 07:27:11 +0200 Subject: [PATCH 01/22] Return error if we fail to read the response body --- cmd/create-account/main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 52301415f..c8e239f29 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -179,7 +179,10 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a body, _ = io.ReadAll(regResp.Body) return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) } - r, _ := io.ReadAll(regResp.Body) + r, err := io.ReadAll(regResp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body (HTTP %d): %w", regResp.StatusCode, err) + } return gjson.GetBytes(r, "access_token").Str, nil } From d72d4f8d5d0016a8dcbf77aba92671f3469eb630 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 17 Oct 2022 10:38:22 +0100 Subject: [PATCH 02/22] Set `org.matrix.msc2285.stable` in `/versions` --- clientapi/routing/routing.go | 1 + 1 file changed, 1 insertion(+) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e72880ec5..ec5ca899e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -70,6 +70,7 @@ func Setup( unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, + "org.matrix.msc2285.stable": true, } for _, msc := range cfg.MSCs.MSCs { unstableFeatures["org.matrix."+msc] = true From 07bfb791ca616bd3a4aa96691b74c96146d59d90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:48:35 +0200 Subject: [PATCH 03/22] Scope transactions to endpoints (#2799) To avoid returning results from e.g. `/redact` on `/sendToDevice` requests. Takes the raw URL path and uses `filepath.Dir` to remove the `txnID` (file) from it. 
Co-authored-by: Neil Alexander --- clientapi/routing/redaction.go | 9 ++--- clientapi/routing/sendevent.go | 4 +-- clientapi/routing/sendtodevice.go | 7 ++-- clientapi/routing/server_notices.go | 7 ++-- internal/transactions/transactions.go | 16 +++++---- internal/transactions/transactions_test.go | 42 ++++++++++++++++++---- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 27f0ba5d0..a0f3b1152 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -19,6 +19,9 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,8 +29,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type redactionContent struct { @@ -51,7 +52,7 @@ func SendRedaction( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -144,7 +145,7 @@ func SendRedaction( // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 85f1053f3..114e9088d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -86,7 +86,7 @@ func SendEvent( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := 
txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -206,7 +206,7 @@ func SendEvent( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to generate the event vs submit diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 4a5f08883..0c0227937 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -16,12 +16,13 @@ import ( "encoding/json" "net/http" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/transactions" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} @@ -33,7 +34,7 @@ func SendToDevice( eventType string, txnID *string, ) util.JSONResponse { if txnID != nil { - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -63,7 +64,7 @@ func SendToDevice( } if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } return res diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 9edeed2f7..7729eddd8 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -21,7 +21,6 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" 
"github.com/matrix-org/gomatrixserverlib/tokens" @@ -29,6 +28,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/version" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -73,7 +74,7 @@ func SendServerNotice( if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok { return *res } } @@ -251,7 +252,7 @@ func SendServerNotice( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(device.AccessToken, *txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res) } // Take a note of how long it took to generate the event vs submit diff --git a/internal/transactions/transactions.go b/internal/transactions/transactions.go index d2eb0f27f..7ff6f5044 100644 --- a/internal/transactions/transactions.go +++ b/internal/transactions/transactions.go @@ -13,6 +13,8 @@ package transactions import ( + "net/url" + "path/filepath" "sync" "time" @@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse type CacheKey struct { AccessToken string TxnID string + Endpoint string } // Cache represents a temporary store for response entries. @@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache { return &t } -// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. +// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Looks in both the txnMaps. // Returns (JSON response, true) if txnID is found, else the returned bool is false. 
-func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { +func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) { t.RLock() defer t.RUnlock() for _, txns := range t.txnsMaps { - res, ok := txns[CacheKey{accessToken, txnID}] + res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] if ok { return res, true } @@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, return nil, false } -// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. +// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Adds to the front txnMap. -func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { +func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) { t.Lock() defer t.Unlock() - - t.txnsMaps[0][CacheKey{accessToken, txnID}] = res + t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res } // cacheCleanService is responsible for cleaning up entries after cleanupPeriod. 
diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index aa837f76c..c552550ac 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,9 @@ package transactions import ( "net/http" + "net/url" + "path/filepath" + "reflect" "strconv" "testing" @@ -24,6 +27,16 @@ type fakeType struct { ID string `json:"ID"` } +func TestCompare(t *testing.T) { + u1, _ := url.Parse("/send/1?accessToken=123") + u2, _ := url.Parse("/send/1") + c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)} + c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)} + if !reflect.DeepEqual(c1, c2) { + t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2) + } +} + var ( fakeAccessToken = "aRandomAccessToken" fakeAccessToken2 = "anotherRandomAccessToken" @@ -34,23 +47,28 @@ var ( fakeResponse2 = &util.JSONResponse{ Code: http.StatusOK, JSON: fakeType{ID: "1"}, } + fakeResponse3 = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "2"}, + } ) // TestCache creates a New Cache and tests AddTransaction & FetchTransaction func TestCache(t *testing.T) { fakeTxnCache := New() - fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) + u, _ := url.Parse("") + fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse) // Add entries for noise. for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, fakeTxnID+strconv.Itoa(i), + u, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } - testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) + testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u) if !ok { t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) } else if testResponse.JSON != fakeResponse.JSON { @@ -59,20 +77,30 @@ func TestCache(t *testing.T) { } // TestCacheScope ensures transactions with the same transaction ID are not shared -// across multiple access tokens. 
+// across multiple access tokens and endpoints. func TestCacheScope(t *testing.T) { cache := New() - cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) - cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) + sendEndpoint, _ := url.Parse("/send/1?accessToken=test") + sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1") + cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3) - if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) } - if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse2.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) } + + // Ensure the txnID is not shared across endpoints + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse3.JSON { + t.Errorf("Wrong cache entry for (%s, %s). 
Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) + } } From 9c189b1b80f9c338ac3cfa71cdaca60016de45f7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 18 Oct 2022 09:51:31 +0100 Subject: [PATCH 04/22] Try to make `AddEvent` less expensive (update to matrix-org/gomatrixserverlib@a72a83f) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1303d4004..911d36c1c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 54055f2fd..a141fc9b4 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ= 
+github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 3aa92efaa3e814ad0596fc5fc174a2e43124dcf5 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 18 Oct 2022 15:59:08 +0100 Subject: [PATCH 05/22] Namespace user API tables (#2806) This migrates all the various user API tables, indices and sequences to be `userapi_`-namespaced, rather than the mess they are all now. --- userapi/consumers/roomserver_test.go | 9 +- .../storage/postgres/account_data_table.go | 8 +- userapi/storage/postgres/accounts_table.go | 14 +-- .../deltas/20200929203058_is_active.go | 4 +- .../deltas/20201001204705_last_seen_ts_ip.go | 12 +- .../2022021013023800_add_account_type.go | 10 +- .../deltas/2022101711000000_rename_tables.go | 102 ++++++++++++++++ userapi/storage/postgres/devices_table.go | 28 ++--- userapi/storage/postgres/key_backup_table.go | 18 +-- .../postgres/key_backup_version_table.go | 20 ++-- userapi/storage/postgres/logintoken_table.go | 12 +- userapi/storage/postgres/openid_table.go | 6 +- userapi/storage/postgres/profile_table.go | 12 +- userapi/storage/postgres/stats_table.go | 16 +-- userapi/storage/postgres/storage.go | 11 ++ userapi/storage/postgres/threepid_table.go | 12 +- userapi/storage/sqlite3/account_data_table.go | 8 +- userapi/storage/sqlite3/accounts_table.go | 14 +-- .../deltas/20200929203058_is_active.go | 20 ++-- .../deltas/20201001204705_last_seen_ts_ip.go | 20 ++-- .../2022021012490600_add_account_type.go | 16 +-- .../deltas/2022101711000000_rename_tables.go | 109 ++++++++++++++++++ userapi/storage/sqlite3/devices_table.go | 24 ++-- 
userapi/storage/sqlite3/key_backup_table.go | 18 +-- .../sqlite3/key_backup_version_table.go | 16 +-- userapi/storage/sqlite3/logintoken_table.go | 12 +- userapi/storage/sqlite3/openid_table.go | 6 +- userapi/storage/sqlite3/profile_table.go | 12 +- userapi/storage/sqlite3/stats_table.go | 16 +-- userapi/storage/sqlite3/storage.go | 11 ++ userapi/storage/sqlite3/threepid_table.go | 12 +- userapi/storage/storage_test.go | 9 +- userapi/storage/tables/stats_table_test.go | 4 +- userapi/userapi_test.go | 14 ++- 34 files changed, 441 insertions(+), 194 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2022101711000000_rename_tables.go create mode 100644 userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 3bbeb439a..e4587670f 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -10,19 +10,24 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + close() + baseclose() + } } func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 67113367b..0b6a3af6d 
100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -26,7 +26,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. -CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -41,15 +41,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 33fb6dd42..7c309eb4f 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). 
@@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" type accountsStatements struct { insertAccountStmt *sql.Stmt diff --git a/userapi/storage/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go index 24f87e073..2c5cc2f58 100644 --- a/userapi/storage/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -7,7 +7,7 @@ import ( ) func UpIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") + _, err := 
tx.ExecContext(ctx, "ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -15,7 +15,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { } func DownIsActive(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN is_deactivated;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index edd3353f0..40e237027 100644 --- a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -8,9 +8,9 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT; -ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS ip TEXT; +ALTER TABLE userapi_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -19,9 +19,9 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices DROP COLUMN last_seen_ts; - ALTER TABLE device_devices DROP COLUMN ip; - ALTER TABLE device_devices DROP COLUMN user_agent;`) + ALTER TABLE 
userapi_devices DROP COLUMN last_seen_ts; + ALTER TABLE userapi_devices DROP COLUMN ip; + ALTER TABLE userapi_devices DROP COLUMN user_agent;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go index eb7c3a958..164847e51 100644 --- a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go +++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go @@ -9,10 +9,10 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; -ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE userapi_accounts ALTER COLUMN account_type DROP DEFAULT;`, ) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) @@ -21,7 +21,7 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;") + _, err := tx.ExecContext(ctx, "ALTER TABLE userapi_accounts DROP COLUMN account_type;") if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git 
a/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..1d73d0af4 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,102 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameSequenceMappings = map[string]string{ + "device_session_id_seq": "userapi_device_session_id_seq", + "account_e2e_room_keys_versions_seq": "userapi_key_backup_versions_seq", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. 
+ +func UpRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(old), pq.QuoteIdentifier(new), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameSequenceMappings { + q := fmt.Sprintf( + "ALTER SEQUENCE IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + q := fmt.Sprintf( + "ALTER INDEX IF EXISTS %s RENAME TO %s;", + pq.QuoteIdentifier(new), pq.QuoteIdentifier(old), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git 
a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index f65681aae..8b7fbd6cf 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -31,10 +31,10 @@ import ( const devicesSchema = ` -- This sequence is used for automatic allocation of session_id. -CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; +CREATE SEQUENCE IF NOT EXISTS userapi_device_session_id_seq START 1; -- Stores data about devices. -CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( -- The access token granted to this device. This has to be the primary key -- so we can distinguish which device is making a given request. access_token TEXT NOT NULL PRIMARY KEY, @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS device_devices ( -- This can be used as a secure substitution of the access token in situations -- where data is associated with access tokens (e.g. transaction storage), -- so we don't have to store users' access tokens everywhere. - session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), + session_id BIGINT NOT NULL DEFAULT nextval('userapi_device_session_id_seq'), -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- access_tokens will be clobbered based on the device ID for a user. device_id TEXT NOT NULL, @@ -65,39 +65,39 @@ CREATE TABLE IF NOT EXISTS device_devices ( ); -- Device IDs must be unique for a given user. 
-CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - 
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { insertDeviceStmt *sql.Stmt diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index ac0e80617..7b58f7bae 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backups_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS userapi_key_backups_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " 
+ + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go index e78e4cd51..67c5e5481 100644 --- a/userapi/storage/postgres/key_backup_version_table.go +++ 
b/userapi/storage/postgres/key_backup_version_table.go @@ -26,40 +26,40 @@ import ( ) const keyBackupVersionTableSchema = ` -CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; +CREATE SEQUENCE IF NOT EXISTS userapi_key_backup_versions_seq; -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. - version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + version BIGINT DEFAULT nextval('userapi_key_backup_versions_seq'), algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, etag TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE 
userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go index 4de96f839..44c6ca4ae 100644 --- a/userapi/storage/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -26,7 +26,7 @@ import ( ) const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -37,17 +37,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. 
-CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS userapi_login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" type loginTokenStatements struct { insertStmt *sql.Stmt @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 29c3ddcb4..06ae30d08 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about openid tokens issued for accounts. 
-CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 6d336eb8e..f686127be 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index c0b317503..20eb0bf46 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -45,7 +45,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -62,7 +62,7 @@ R30Users 
counts the number of 30 day retained users, defined as: const countR30UsersSQL = ` SELECT platform, COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -75,7 +75,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -121,7 +121,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type = ANY($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type = ANY($1) ` // $1 = All non guest AccountType IDs @@ -134,7 +134,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $2 AND appservice_id IS NULL THEN 'guest' WHEN account_type = ANY($1) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $3 ) AS t GROUP BY user_type ` @@ -143,14 +143,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = u.localpart - INNER JOIN account_accounts a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 7d3b9b6a5..c059e3e60 100644 --- 
a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" // Import the postgres database driver. @@ -36,6 +37,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewPostgresAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 63c08d61f..11af76161 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -26,7 +26,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -37,20 +37,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT 
threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { selectLocalpartForThreePIDStmt *sql.Stmt diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index cfd8568a9..af12decb3 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -25,7 +25,7 @@ import ( const accountDataSchema = ` -- Stores data about accounts data. -CREATE TABLE IF NOT EXISTS account_data ( +CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) @@ -40,15 +40,15 @@ CREATE TABLE IF NOT EXISTS account_data ( ` const insertAccountDataSQL = ` - INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 
484e90056..671c1aa04 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -32,7 +32,7 @@ import ( const accountsSchema = ` -- Stores data about accounts. -CREATE TABLE IF NOT EXISTS account_accounts ( +CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- When this account was first created, as a unix timestamp (ms resolution). @@ -51,22 +51,22 @@ CREATE TABLE IF NOT EXISTS account_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + - "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" const deactivateAccountSQL = "" + - "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" type accountsStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go 
b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index e25efc695..9158cb365 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -8,8 +8,8 @@ import ( func UpIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -17,13 +17,13 @@ CREATE TABLE account_accounts ( is_deactivated BOOLEAN DEFAULT 0 ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -32,21 +32,21 @@ DROP TABLE account_accounts_tmp;`) func DownIsActive(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id ) SELECT localpart, created_ts, password_hash, appservice_id - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -DROP TABLE account_accounts_tmp;`) +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go 
b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 7f7e95d2d..a9224db6b 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -8,8 +8,8 @@ import ( func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` - ALTER TABLE device_devices RENAME TO device_devices_tmp; - CREATE TABLE device_devices ( + ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; + CREATE TABLE userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -22,12 +22,12 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { UNIQUE (localpart, device_id) ); INSERT - INTO device_devices ( + INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' - FROM device_devices_tmp; - DROP TABLE device_devices_tmp;`) + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } @@ -36,8 +36,8 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` -ALTER TABLE device_devices RENAME TO device_devices_tmp; -CREATE TABLE IF NOT EXISTS device_devices ( +ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -47,12 +47,12 @@ CREATE TABLE IF NOT EXISTS device_devices ( UNIQUE (localpart, device_id) ); INSERT -INTO device_devices ( +INTO userapi_devices ( access_token, session_id, device_id, localpart, created_ts, display_name ) SELECT access_token, session_id, device_id, localpart, created_ts, display_name -FROM device_devices_tmp; -DROP TABLE 
device_devices_tmp;`) +FROM userapi_devices_tmp; +DROP TABLE userapi_devices_tmp;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 46532698c..230bc1433 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -9,8 +9,8 @@ import ( func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { // initially set every account to useraccount, change appservice and guest accounts afterwards // (user = 1, guest = 2, admin = 3, appservice = 4) - _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp; -CREATE TABLE account_accounts ( + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; +CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL, password_hash TEXT, @@ -19,15 +19,15 @@ CREATE TABLE account_accounts ( account_type INTEGER NOT NULL ); INSERT - INTO account_accounts ( + INTO userapi_accounts ( localpart, created_ts, password_hash, appservice_id, account_type ) SELECT localpart, created_ts, password_hash, appservice_id, 1 - FROM account_accounts_tmp + FROM userapi_accounts_tmp ; -UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; -UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; -DROP TABLE account_accounts_tmp;`) +UPDATE userapi_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE userapi_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE userapi_accounts_tmp;`) if err != nil { return fmt.Errorf("failed to add column: %w", err) } @@ -35,7 +35,7 @@ DROP TABLE account_accounts_tmp;`) } func DownAddAccountType(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER 
TABLE account_accounts DROP COLUMN account_type;`) + _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts DROP COLUMN account_type;`) if err != nil { return fmt.Errorf("failed to execute downgrade: %w", err) } diff --git a/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go new file mode 100644 index 000000000..4ca1dc475 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022101711000000_rename_tables.go @@ -0,0 +1,109 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +var renameTableMappings = map[string]string{ + "account_accounts": "userapi_accounts", + "account_data": "userapi_account_datas", + "device_devices": "userapi_devices", + "account_e2e_room_keys": "userapi_key_backups", + "account_e2e_room_keys_versions": "userapi_key_backup_versions", + "login_tokens": "userapi_login_tokens", + "open_id_tokens": "userapi_openid_tokens", + "account_profiles": "userapi_profiles", + "account_threepid": "userapi_threepids", +} + +var renameIndicesMappings = map[string]string{ + "device_localpart_id_idx": "userapi_device_localpart_id_idx", + "e2e_room_keys_idx": "userapi_key_backups_idx", + "e2e_room_keys_versions_idx": "userapi_key_backups_versions_idx", + "account_e2e_room_keys_versions_idx": "userapi_key_backup_versions_idx", + "login_tokens_expiration_idx": "userapi_login_tokens_expiration_idx", + "account_threepid_localpart": "userapi_threepid_idx", +} + +func UpRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. 
+ var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", old, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", old, new, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", old, new, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", old, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, old, new, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", old)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", old, new, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", old, new, err) + } + } + return nil +} + +func DownRenameTables(ctx context.Context, tx *sql.Tx) error { + for old, new := range renameTableMappings { + // SQLite has no "IF EXISTS" so check if the table exists. 
+ var name string + if err := tx.QueryRowContext( + ctx, "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = $1;", new, + ).Scan(&name); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + q := fmt.Sprintf( + "ALTER TABLE %s RENAME TO %s;", new, old, + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("rename table %q to %q error: %w", new, old, err) + } + } + for old, new := range renameIndicesMappings { + var query string + if err := tx.QueryRowContext( + ctx, "SELECT sql FROM sqlite_schema WHERE type = 'index' AND name = $1;", new, + ).Scan(&query); err != nil { + if err == sql.ErrNoRows { + continue + } + return err + } + query = strings.Replace(query, new, old, 1) + if _, err := tx.ExecContext(ctx, fmt.Sprintf("DROP INDEX %s;", new)); err != nil { + return fmt.Errorf("drop index %q to %q error: %w", new, old, err) + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return fmt.Errorf("recreate index %q to %q error: %w", new, old, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 27a7524d6..e53a08062 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -35,7 +35,7 @@ const devicesSchema = ` -- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; -- Stores data about devices. 
-CREATE TABLE IF NOT EXISTS device_devices ( +CREATE TABLE IF NOT EXISTS userapi_devices ( access_token TEXT PRIMARY KEY, session_id INTEGER, device_id TEXT , @@ -51,38 +51,38 @@ CREATE TABLE IF NOT EXISTS device_devices ( ` const insertDeviceSQL = "" + - "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" const selectDevicesCountSQL = "" + - "SELECT COUNT(access_token) FROM device_devices" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND 
device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" type devicesStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 81726edf9..7883ffb19 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -26,7 +26,7 @@ import ( ) const keyBackupTableSchema = ` -CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +CREATE TABLE IF NOT EXISTS userapi_key_backups ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -37,31 +37,31 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); -CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON userapi_key_backups(user_id, room_id, session_id, version); +CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON userapi_key_backups(user_id, version); ` const insertBackupKeySQL = "" + - "INSERT INTO account_e2e_room_keys(user_id, room_id, 
session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + "INSERT INTO userapi_key_backups(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const updateBackupKeySQL = "" + - "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + "UPDATE userapi_key_backups SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" const countKeysSQL = "" + - "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" const selectKeysSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" const selectKeysByRoomIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3" const selectKeysByRoomIDAndSessionIDSQL = "" + - "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" type keyBackupStatements struct { diff --git a/userapi/storage/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go index 
e85e6f08b..37bc13ed1 100644 --- a/userapi/storage/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -27,7 +27,7 @@ import ( const keyBackupVersionTableSchema = ` -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +CREATE TABLE IF NOT EXISTS userapi_key_backup_versions ( user_id TEXT NOT NULL, -- this means no 2 users will ever have the same version of e2e session backups which strictly -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. @@ -38,26 +38,26 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( deleted INTEGER DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_key_backup_versions_idx ON userapi_key_backup_versions(user_id, version); ` const insertKeyBackupSQL = "" + - "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + "INSERT INTO userapi_key_backup_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" const updateKeyBackupAuthDataSQL = "" + - "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" const updateKeyBackupETagSQL = "" + - "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + "UPDATE userapi_key_backup_versions SET etag = $1 WHERE user_id = $2 AND version = $3" const deleteKeyBackupSQL = "" + - "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + "UPDATE userapi_key_backup_versions SET deleted=1 WHERE user_id = $1 AND version = $2" const selectKeyBackupSQL = "" + - "SELECT algorithm, auth_data, etag, deleted FROM 
account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + "SELECT algorithm, auth_data, etag, deleted FROM userapi_key_backup_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM userapi_key_backup_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go index 78d42029a..2abdcb95e 100644 --- a/userapi/storage/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -32,7 +32,7 @@ type loginTokenStatements struct { } const loginTokenSchema = ` -CREATE TABLE IF NOT EXISTS login_tokens ( +CREATE TABLE IF NOT EXISTS userapi_login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- When the token expires @@ -43,17 +43,17 @@ CREATE TABLE IF NOT EXISTS login_tokens ( ); -- This index allows efficient garbage collection of expired tokens. 
-CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON userapi_login_tokens(token_expires_at); ` const insertLoginTokenSQL = "" + - "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + "INSERT INTO userapi_login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" const deleteLoginTokenSQL = "" + - "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + "DELETE FROM userapi_login_tokens WHERE token = $1 OR token_expires_at <= $2" const selectLoginTokenSQL = "" + - "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + "SELECT user_id FROM userapi_login_tokens WHERE token = $1 AND token_expires_at > $2" func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { s := &loginTokenStatements{} @@ -78,7 +78,7 @@ func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx // deleteByToken removes the named token. // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. -// The login_tokens_expiration_idx index should make that efficient. +// The userapi_login_tokens_expiration_idx index should make that efficient. func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index d6090e0da..875f1a9a5 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -13,7 +13,7 @@ import ( const openIDTokenSchema = ` -- Stores data about accounts. 
-CREATE TABLE IF NOT EXISTS open_id_tokens ( +CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( -- The value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account @@ -24,10 +24,10 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ` const insertOpenIDTokenSQL = "" + - "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 3050ff4b5..267daf044 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -27,7 +27,7 @@ import ( const profilesSchema = ` -- Stores data about accounts profiles. 
-CREATE TABLE IF NOT EXISTS account_profiles ( +CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL PRIMARY KEY, -- The display name for this account @@ -38,19 +38,19 @@ CREATE TABLE IF NOT EXISTS account_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 8aa1746c5..35e3c653e 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -46,7 +46,7 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON usera const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + - " SELECT localpart FROM device_devices WHERE last_seen_ts > $1 " + + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + " GROUP BY localpart" + " ) u" @@ -63,7 +63,7 @@ R30Users counts the number of 30 day 
retained users, defined as: const countR30UsersSQL = ` SELECT platform, COUNT(*) FROM ( SELECT users.localpart, platform, users.created_ts, MAX(uip.last_seen_ts) - FROM account_accounts users + FROM userapi_accounts users INNER JOIN (SELECT localpart, last_seen_ts, @@ -76,7 +76,7 @@ SELECT platform, COUNT(*) FROM ( ELSE 'unknown' END AS platform - FROM device_devices + FROM userapi_devices ) uip ON users.localpart = uip.localpart AND users.account_type <> 4 @@ -126,7 +126,7 @@ GROUP BY client_type ` const countUserByAccountTypeSQL = ` -SELECT COUNT(*) FROM account_accounts WHERE account_type IN ($1) +SELECT COUNT(*) FROM userapi_accounts WHERE account_type IN ($1) ` // $1 = Guest AccountType @@ -139,7 +139,7 @@ SELECT user_type, COUNT(*) AS count FROM ( WHEN account_type = $4 AND appservice_id IS NULL THEN 'guest' WHEN account_type IN ($5) AND appservice_id IS NOT NULL THEN 'bridged' END AS user_type - FROM account_accounts + FROM userapi_accounts WHERE created_ts > $8 ) AS t GROUP BY user_type ` @@ -148,14 +148,14 @@ SELECT user_type, COUNT(*) AS count FROM ( const updateUserDailyVisitsSQL = ` INSERT INTO userapi_daily_visits(localpart, device_id, timestamp, user_agent) SELECT u.localpart, u.device_id, $1, MAX(u.user_agent) - FROM device_devices AS u + FROM userapi_devices AS u LEFT JOIN ( SELECT localpart, device_id, timestamp FROM userapi_daily_visits WHERE timestamp = $1 ) udv ON u.localpart = udv.localpart AND u.device_id = udv.device_id - INNER JOIN device_devices d ON d.localpart = u.localpart - INNER JOIN account_accounts a ON a.localpart = u.localpart + INNER JOIN userapi_devices d ON d.localpart = u.localpart + INNER JOIN userapi_accounts a ON a.localpart = u.localpart WHERE $2 <= d.last_seen_ts AND d.last_seen_ts < $3 AND a.account_type in (1, 3) GROUP BY u.localpart, u.device_id diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 78b7ce588..dd33dc0cf 100644 --- a/userapi/storage/sqlite3/storage.go +++ 
b/userapi/storage/sqlite3/storage.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // NewDatabase creates a new accounts and profiles database @@ -34,6 +35,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: rename tables", + Up: deltas.UpRenameTables, + Down: deltas.DownRenameTables, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + accountDataTable, err := NewSQLiteAccountDataTable(db) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index fa174eed5..73af139db 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -27,7 +27,7 @@ import ( const threepidSchema = ` -- Stores data about third party identifiers -CREATE TABLE IF NOT EXISTS account_threepid ( +CREATE TABLE IF NOT EXISTS userapi_threepids ( -- The third party identifier threepid TEXT NOT NULL, -- The 3PID medium @@ -38,20 +38,20 @@ CREATE TABLE IF NOT EXISTS account_threepid ( PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" const insertThreePIDSQL = "" + - "INSERT INTO 
account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" const deleteThreePIDSQL = "" + - "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" type threepidStatements struct { db *sql.DB diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 1538a8138..8e5b32b6a 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -16,6 +16,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -29,14 +30,18 @@ var ( ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } - return db, close + return db, func() { + close() + baseclose() + } } // Tests storing and getting account data diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 11521c8b0..c4aec552c 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -106,7 +106,7 @@ func mustUpdateDeviceLastSeen( timestamp time.Time, ) { t.Helper() - _, err := db.ExecContext(ctx, "UPDATE device_devices SET last_seen_ts = $1 WHERE 
localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } @@ -119,7 +119,7 @@ func mustUserUpdateRegistered( localpart string, timestamp time.Time, ) { - _, err := db.ExecContext(ctx, "UPDATE account_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 984fe8854..4417f4dc0 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -25,6 +25,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/gomatrixserverlib" @@ -48,9 +49,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } + base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - - accountDB, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { @@ -64,9 +65,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return 
&internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, - }, accountDB, close + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB, func() { + close() + baseclose() + } } func TestQueryProfile(t *testing.T) { From 241d5c47dfa9e5cfadc350f688aab30f9e539fbb Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 19 Oct 2022 10:03:16 +0000 Subject: [PATCH 06/22] Refactor Federation Destination Queues (#2807) This is a refactor of the federation destination queues. It fixes a few things, namely: - actually retry outgoing events with backoff behaviour - obtain enough events from the database to fill messages as much as possible - minimize the amount of running goroutines - use pure timers for backoff - don't restart queue unless necessary - close the background task when backing off - increase max edus in a transaction to match the spec - cleanup timers more aggressively to reduce memory usage - add jitter to backoff timers to reduce resource spikes - add a bunch of tests (with real and fake databases) to ensure everything is working --- federationapi/federationapi.go | 9 +- federationapi/queue/destinationqueue.go | 374 ++++--- federationapi/queue/queue.go | 26 +- federationapi/queue/queue_test.go | 1047 +++++++++++++++++++ federationapi/statistics/statistics.go | 131 ++- federationapi/statistics/statistics_test.go | 19 +- federationapi/storage/shared/storage.go | 4 + go.mod | 2 +- 8 files changed, 1410 insertions(+), 202 deletions(-) create mode 100644 federationapi/queue/queue_test.go diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4a13c9d9b..f6dace702 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -116,17 +116,14 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := &statistics.Statistics{ - DB: federationDB, - FailuresUntilBlacklist: cfg.FederationMaxRetries, - } + stats := statistics.NewStatistics(federationDB,
cfg.FederationMaxRetries+1) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, - cfg.Matrix.ServerName, federation, rsAPI, stats, + cfg.Matrix.ServerName, federation, rsAPI, &stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, @@ -183,5 +180,5 @@ func NewInternalAPI( } time.AfterFunc(time.Minute, cleanExpiredEDUs) - return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 5cb8cae1f..00e02b2d9 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -35,7 +35,7 @@ import ( const ( maxPDUsPerTransaction = 50 - maxEDUsPerTransaction = 50 + maxEDUsPerTransaction = 100 maxPDUsInMemory = 128 maxEDUsInMemory = 128 queueIdleTimeout = time.Second * 30 @@ -64,7 +64,6 @@ type destinationQueue struct { pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. @@ -75,6 +74,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } + // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -102,12 +102,12 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.overflowed.Store(true) } oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. 
- oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } + } else { + oq.overflowed.Store(true) } } @@ -147,12 +147,37 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.overflowed.Store(true) } oq.pendingMutex.Unlock() - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - select { - case oq.notify <- struct{}{}: - default: + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() } + } else { + oq.overflowed.Store(true) + } +} + +// handleBackoffNotifier is registered as the backoff notification +// callback with Statistics. It will wakeup and notify the queue +// if the queue is currently backing off. +func (oq *destinationQueue) handleBackoffNotifier() { + // Only wake up the queue if it is backing off. + // Otherwise there is no pending work for the queue to handle + // so waking the queue would be a waste of resources. + if oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } +} + +// wakeQueueAndNotify ensures the destination queue is running and notifies it +// that there is pending work. +func (oq *destinationQueue) wakeQueueAndNotify() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + + // Notify the queue that there are events ready to send. + select { + case oq.notify <- struct{}{}: + default: } } @@ -161,10 +186,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share // then we will interrupt the backoff, causing any federation // requests to retry. func (oq *destinationQueue) wakeQueueIfNeeded() { - // If we are backing off then interrupt the backoff. + // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { - oq.interruptBackoff <- true + destinationQueueBackingOff.Dec() } + // If we aren't running then wake up the queue. if !oq.running.Load() { // Start the queue. 
@@ -196,38 +222,54 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotEDUs[edu.receipt.String()] = struct{}{} } + overflowed := false if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - let's request no more than that. - if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { + if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil { + if len(pdus) == maxPDUsInMemory { + overflowed = true + } for receipt, pdu := range pdus { if _, ok := gotPDUs[receipt.String()]; ok { continue } oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) retrieved = true + if len(oq.pendingPDUs) == maxPDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) } } + if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { // We have room in memory for some EDUs - let's request no more than that. - if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { + if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil { + if len(edus) == maxEDUsInMemory { + overflowed = true + } for receipt, edu := range edus { if _, ok := gotEDUs[receipt.String()]; ok { continue } oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) retrieved = true + if len(oq.pendingEDUs) == maxEDUsInMemory { + break + } } } else { logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) } } + // If we've retrieved all of the events from the database with room to spare // in memory then we'll no longer consider this queue to be overflowed. - if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { + if !overflowed { oq.overflowed.Store(false) + } else { } // If we've retrieved some events then notify the destination queue goroutine. 
if retrieved { @@ -238,6 +280,24 @@ func (oq *destinationQueue) getPendingFromDatabase() { } } +// checkNotificationsOnClose checks for any remaining notifications +// and starts a new backgroundSend goroutine if any exist. +func (oq *destinationQueue) checkNotificationsOnClose() { + // NOTE : If we are stopping the queue due to blacklist then it + // doesn't matter if we have been notified of new work since + // this queue instance will be deleted anyway. + if !oq.statistics.Blacklisted() { + select { + case <-oq.notify: + // We received a new notification in between the + // idle timeout firing and stopping the goroutine. + // Immediately restart the queue. + oq.wakeQueueAndNotify() + default: + } + } +} + // backgroundSend is the worker goroutine for sending events. func (oq *destinationQueue) backgroundSend() { // Check if a worker is already running, and if it isn't, then @@ -245,10 +305,17 @@ func (oq *destinationQueue) backgroundSend() { if !oq.running.CompareAndSwap(false, true) { return } + + // Register queue cleanup functions. + // NOTE : The ordering here is very intentional. + defer oq.checkNotificationsOnClose() + defer oq.running.Store(false) + destinationQueueRunning.Inc() defer destinationQueueRunning.Dec() - defer oq.queues.clearQueue(oq) - defer oq.running.Store(false) + + idleTimeout := time.NewTimer(queueIdleTimeout) + defer idleTimeout.Stop() // Mark the queue as overflowed, so we will consult the database // to see if there's anything new to send. @@ -261,59 +328,33 @@ func (oq *destinationQueue) backgroundSend() { oq.getPendingFromDatabase() } + // Reset the queue idle timeout. + if !idleTimeout.Stop() { + select { + case <-idleTimeout.C: + default: + } + } + idleTimeout.Reset(queueIdleTimeout) + // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. 
select { case <-oq.notify: // There's work to do, either because getPendingFromDatabase - // told us there is, or because a new event has come in via - // sendEvent/sendEDU. - case <-time.After(queueIdleTimeout): + // told us there is, a new event has come in via sendEvent/sendEDU, + // or we are backing off and it is time to retry. + case <-idleTimeout.C: // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to // send. return case <-oq.process.Context().Done(): // The parent process is shutting down, so stop. + oq.statistics.ClearBackoff() return } - // If we are backing off this server then wait for the - // backoff duration to complete first, or until explicitly - // told to retry. - until, blacklisted := oq.statistics.BackoffInfo() - if blacklisted { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs { - oq.pendingPDUs[i] = nil - } - for i := range oq.pendingEDUs { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = nil - oq.pendingEDUs = nil - oq.pendingMutex.Unlock() - return - } - if until != nil && until.After(time.Now()) { - // We haven't backed off yet, so wait for the suggested amount of - // time. - duration := time.Until(*until) - logrus.Debugf("Backing off %q for %s", oq.destination, duration) - oq.backingOff.Store(true) - destinationQueueBackingOff.Inc() - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - } - destinationQueueBackingOff.Dec() - oq.backingOff.Store(false) - } - // Work out which PDUs/EDUs to include in the next transaction. 
oq.pendingMutex.RLock() pduCount := len(oq.pendingPDUs) @@ -328,99 +369,52 @@ func (oq *destinationQueue) backgroundSend() { toSendEDUs := oq.pendingEDUs[:eduCount] oq.pendingMutex.RUnlock() + // If we didn't get anything from the database and there are no + // pending EDUs then there's nothing to do - stop here. + if pduCount == 0 && eduCount == 0 { + continue + } + // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. - oq.statistics.Failure() - - } else if transaction { - // If we successfully sent the transaction then clear out - // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() - oq.pendingMutex.Lock() - for i := range oq.pendingPDUs[:pc] { - oq.pendingPDUs[i] = nil + _, blacklisted := oq.statistics.Failure() + if !blacklisted { + // Register the backoff state and exit the goroutine. + // It'll get restarted automatically when the backoff + // completes. + oq.backingOff.Store(true) + destinationQueueBackingOff.Inc() + return + } else { + // Immediately trigger the blacklist logic. + oq.blacklistDestination() + return } - for i := range oq.pendingEDUs[:ec] { - oq.pendingEDUs[i] = nil - } - oq.pendingPDUs = oq.pendingPDUs[pc:] - oq.pendingEDUs = oq.pendingEDUs[ec:] - oq.pendingMutex.Unlock() + } else { + oq.handleTransactionSuccess(pduCount, eduCount) } } } // nextTransaction creates a new transaction from the pending event -// queue and sends it. Returns true if a transaction was sent or -// false otherwise. +// queue and sends it. +// Returns an error if the transaction wasn't sent. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (bool, int, int, error) { - // If there's no projected transaction ID then generate one. 
If - // the transaction succeeds then we'll set it back to "" so that - // we generate a new one next time. If it fails, we'll preserve - // it so that we retry with the same transaction ID. - oq.transactionIDMutex.Lock() - if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - oq.transactionIDMutex.Unlock() - +) error { // Create the transaction. - t := gomatrixserverlib.Transaction{ - PDUs: []json.RawMessage{}, - EDUs: []gomatrixserverlib.EDU{}, - } - t.Origin = oq.origin - t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = oq.transactionID - - // If we didn't get anything from the database and there are no - // pending EDUs then there's nothing to do - stop here. - if len(pdus) == 0 && len(edus) == 0 { - return false, 0, 0, nil - } - - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt - - // Go through PDUs that we retrieved from the database, if any, - // and add them into the transaction. - for _, pdu := range pdus { - if pdu == nil || pdu.pdu == nil { - continue - } - // Append the JSON of the event, since this is a json.RawMessage type in the - // gomatrixserverlib.Transaction struct - t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) - } - - // Do the same for pending EDUS in the queue. - for _, edu := range edus { - if edu == nil || edu.edu == nil { - continue - } - t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) - } - + t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) // Try to send the transaction to the destination server. 
- // TODO: we should check for 500-ish fails vs 400-ish here, - // since we shouldn't queue things indefinitely in response - // to a 400-ish error ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() _, err := oq.client.SendTransaction(ctx, t) - switch err.(type) { + switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. if pduReceipts != nil { @@ -439,16 +433,128 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return true, len(t.PDUs), len(t.EDUs), nil + return nil case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. - return false, 0, 0, err + + // TODO: we should check for 500-ish fails vs 400-ish here, + // since we shouldn't queue things indefinitely in response + // to a 400-ish error + code := errResponse.Code + logrus.Debug("Transaction failed with HTTP", code) + return err default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return false, 0, 0, err + return err + } +} + +// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus. +// It also returns the associated event receipts so they can be cleaned from the database in +// the case of a successful transaction. +func (oq *destinationQueue) createTransaction( + pdus []*queuedPDU, + edus []*queuedEDU, +) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { + // If there's no projected transaction ID then generate one. If + // the transaction succeeds then we'll set it back to "" so that + // we generate a new one next time. If it fails, we'll preserve + // it so that we retry with the same transaction ID. 
+ oq.transactionIDMutex.Lock() + if oq.transactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + } + oq.transactionIDMutex.Unlock() + + t := gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{}, + EDUs: []gomatrixserverlib.EDU{}, + } + t.Origin = oq.origin + t.Destination = oq.destination + t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = oq.transactionID + + var pduReceipts []*shared.Receipt + var eduReceipts []*shared.Receipt + + // Go through PDUs that we retrieved from the database, if any, + // and add them into the transaction. + for _, pdu := range pdus { + // These should never be nil. + if pdu == nil || pdu.pdu == nil { + continue + } + // Append the JSON of the event, since this is a json.RawMessage type in the + // gomatrixserverlib.Transaction struct + t.PDUs = append(t.PDUs, pdu.pdu.JSON()) + pduReceipts = append(pduReceipts, pdu.receipt) + } + + // Do the same for pending EDUS in the queue. + for _, edu := range edus { + // These should never be nil. + if edu == nil || edu.edu == nil { + continue + } + t.EDUs = append(t.EDUs, *edu.edu) + eduReceipts = append(eduReceipts, edu.receipt) + } + + return t, pduReceipts, eduReceipts +} + +// blacklistDestination removes all pending PDUs and EDUs that have been cached +// and deletes this queue. +func (oq *destinationQueue) blacklistDestination() { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. 
+ logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = nil + oq.pendingEDUs = nil + oq.pendingMutex.Unlock() + + // Delete this queue as no more messages will be sent to this + // destination until it is no longer blacklisted. + oq.statistics.AssignBackoffNotifier(nil) + oq.queues.clearQueue(oq) +} + +// handleTransactionSuccess updates the cached event queues as well as the success and +// backoff information for this server. +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { + // If we successfully sent the transaction then clear out + // the pending events and EDUs, and wipe our transaction ID. + oq.statistics.Success() + oq.pendingMutex.Lock() + for i := range oq.pendingPDUs[:pduCount] { + oq.pendingPDUs[i] = nil + } + for i := range oq.pendingEDUs[:eduCount] { + oq.pendingEDUs[i] = nil + } + oq.pendingPDUs = oq.pendingPDUs[pduCount:] + oq.pendingEDUs = oq.pendingEDUs[eduCount:] + oq.pendingMutex.Unlock() + + if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 { + select { + case oq.notify <- struct{}{}: + default: + } } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8245aa5bd..68f789e37 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -162,23 +162,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d if !ok || oq == nil { destinationQueueTotal.Inc() oq = &destinationQueue{ - queues: oqs, - db: oqs.db, - process: oqs.process, - rsAPI: oqs.rsAPI, - origin: oqs.origin, - destination: destination, - client: oqs.client, - statistics: oqs.statistics.ForServer(destination), - notify: make(chan struct{}, 1), - interruptBackoff: make(chan bool), - signing: oqs.signing, + queues: oqs, + db: oqs.db, + process: oqs.process, + rsAPI: 
oqs.rsAPI, + origin: oqs.origin, + destination: destination, + client: oqs.client, + statistics: oqs.statistics.ForServer(destination), + notify: make(chan struct{}, 1), + signing: oqs.signing, } + oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier) oqs.queues[destination] = oq } return oq } +// clearQueue removes the queue for the provided destination from the +// set of destination queues. func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { oqs.queuesMutex.Lock() defer oqs.queuesMutex.Unlock() @@ -332,7 +334,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } + oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { + queue.statistics.ClearBackoff() queue.wakeQueueIfNeeded() } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go new file mode 100644 index 000000000..6da9e6b30 --- /dev/null +++ b/federationapi/queue/queue_test.go @@ -0,0 +1,1047 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package queue + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + "gotest.tools/v3/poll" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { + if realDatabase { + // Real Database/s + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, b.ProcessContext, func() { + dbClose() + baseClose() + } + } else { + // Fake Database + db := createDatabase() + b := struct { + ProcessContext *process.ProcessContext + }{ProcessContext: process.NewProcessContext()} + return db, b.ProcessContext, func() {} + } +} + +func createDatabase() storage.Database { + return &fakeDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: 
make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + } +} + +type fakeDatabase struct { + storage.Database + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} +} + +var nidMutex sync.Mutex +var nid = int64(0) + +func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingPDUs[&receipt] = &event + return &receipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + receipt := shared.NewReceipt(nid) + d.pendingEDUs[&receipt] = &edu + return &receipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingPDUs[receipt]; ok { + pdus[receipt] = event + pduCount++ + if pduCount == limit { + break + } + } 
+ } + } + return pdus, nil +} + +func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingEDUs[receipt]; ok { + edus[receipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[receipt]; ok { + if _, ok := d.associatedPDUs[serverName]; !ok { + d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[receipt]; ok { + if _, ok := d.associatedEDUs[serverName]; !ok { + d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, receipt := range receipts { + delete(pdus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) 
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, receipt := range receipts { + delete(edus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersFromBlacklist() 
error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +type stubFederationRoomServerAPI struct { + rsapi.FederationRoomserverAPI +} + +func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { + res.Banned = false + return nil +} + +type stubFederationClient struct { + api.FederationClient + shouldTxSucceed bool + txCount atomic.Uint32 +} + +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + var result error + if !f.shouldTxSucceed { + result = fmt.Errorf("transaction failed") + } + + f.txCount.Add(1) + return gomatrixserverlib.RespSend{}, result +} + +func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { + t.Helper() + content := `{"type":"m.room.message"}` + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { + t.Helper() + return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} +} + +func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { + db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) + + fc := 
&stubFederationClient{ + shouldTxSucceed: shouldTxSucceed, + txCount: *atomic.NewUint32(0), + } + rs := &stubFederationRoomServerAPI{} + stats := statistics.NewStatistics(db, failuresUntilBlacklist) + signingInfo := &SigningInfo{ + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + } + queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) + + return db, fc, queues, processContext, close +} + +func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUOnFailStoredInDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreatePDU(t) + err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. 
Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + ev = mustCreateEDU(t) + err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + pollEnd := time.Now().Add(1 * time.Second) + immediateCheck := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Error(fmt.Errorf("The backoff was interrupted early")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the backoff to be interrupted before + // reporting that it wasn't. + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + queues.statistics.ForServer(destination).Failure() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsPDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. 
+ dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsEDUSuccessfully(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. 
+ dest := queues.getQueue(destination) + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. 
Currently present EDU: %d", len(data)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxPDUsPerTransaction + pduMultiplier := uint32(3) + for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + } + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxEDUsPerTransaction + eduMultiplier := uint32(3) + for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEvent() + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all events to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestSendPDUAndEDUBatches(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + + // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // Populate database with > maxEDUsPerTransaction + multiplier := uint32(3) + + for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + } + + for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { + ev := mustCreateEDU(t) + ephemeralJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + } + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEvent() + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for all 
events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + // }) +} + +func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + dest := queues.getQueue(destination) + queues.statistics.ForServer(destination).Failure() + + ev := mustCreatePDU(t) + headeredJSON, _ := json.Marshal(ev) + nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) + now := gomatrixserverlib.AsTimestamp(time.Now()) + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1)) + db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + + pollEnd := time.Now().Add(3 * time.Second) + runningCheck := func(log poll.LogT) poll.Result { + if dest.running.Load() || fc.txCount.Load() > 0 { + return poll.Error(fmt.Errorf("The queue was started")) + } + if time.Now().After(pollEnd) { + // Allow more than enough time for the queue to be started in the case + // of backoff triggering it to start. + return poll.Success() + } + return poll.Continue("waiting to ensure queue doesn't start.") + } + poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { + // NOTE : Only one test case against real databases can be run at a time. 
+ t.Parallel() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + // NOTE : getQueue before sending event to ensure we grab the same queue reference + // before it is blacklisted and deleted. + dest := queues.getQueue(destination) + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + edu := mustCreateEDU(t) + errEDU := queues.SendEDU(edu, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, errEDU) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilBlacklist { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 1 && len(eduData) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + if !dest.running.Load() { + return poll.Success() + } + return poll.Continue("waiting for queue to stop completely") + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + return poll.Continue("waiting for more send attempts before checking database. 
Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrPDU) + eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200) + assert.NoError(t, dbErrEDU) + if len(pduData) == 0 && len(eduData) == 0 { + return poll.Success() + } + return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData)) + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) + }) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index db6d5c735..2ba99112c 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -2,6 +2,7 @@ package statistics import ( "math" + "math/rand" "sync" "time" @@ -20,12 +21,23 @@ type Statistics struct { servers map[gomatrixserverlib.ServerName]*ServerStatistics mutex sync.RWMutex + backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffMutex sync.RWMutex + // How many times should we tolerate consecutive failures before we // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 } +func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { + return Statistics{ + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + } +} + // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. 
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { @@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - interrupt chan struct{} // interrupts the backoff goroutine - successCounter atomic.Uint32 // how many times have we succeeded? + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex } +const maxJitterMultiplier = 1.4 +const minJitterMultiplier = 0.8 + // duration returns how long the next backoff interval should be. func (s *ServerStatistics) duration(count uint32) time.Duration { - return time.Second * time.Duration(math.Exp2(float64(count))) + // Add some jitter to minimise the chance of having multiple backoffs + // ending at the same time. 
+ jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier + duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000) + return duration } // cancel will interrupt the currently active backoff. func (s *ServerStatistics) cancel() { s.blacklisted.Store(false) s.backoffUntil.Store(time.Time{}) - select { - case s.interrupt <- struct{}{}: - default: - } + + s.ClearBackoff() +} + +// AssignBackoffNotifier configures the callback to invoke when +// a backoff completes. +func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + s.backoffNotifier = notifier } // Success updates the server statistics with a new successful @@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. func (s *ServerStatistics) Success() { s.cancel() - s.successCounter.Inc() s.backoffCount.Store(0) + s.successCounter.Inc() if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() { } // Failure marks a failure and starts backing off if needed. -// The next call to BackoffIfRequired will do the right thing -// after this. It will return the time that the current failure +// It will return the time that the current failure // will result in backoff waiting until, and a bool signalling // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { + // Return immediately if we have blacklisted this node. + if s.blacklisted.Load() { + return time.Time{}, true + } + // If we aren't already backing off, this call will start - // a new backoff period.
Increase the failure counter and + // a new backoff period, increase the failure counter and // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { @@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) } } + s.ClearBackoff() return time.Time{}, true } - go func() { - until, ok := s.backoffUntil.Load().(time.Time) - if ok && !until.IsZero() { - select { - case <-time.After(time.Until(until)): - case <-s.interrupt: - } - s.backoffStarted.Store(false) - } - }() + // We're starting a new back off so work out what the next interval + // will be. + count := s.backoffCount.Load() + until := time.Now().Add(s.duration(count)) + s.backoffUntil.Store(until) + + s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) } - // Check if we have blacklisted this node. - if s.blacklisted.Load() { - return time.Now(), true - } + return s.backoffUntil.Load().(time.Time), false +} - // If we're already backing off and we haven't yet surpassed - // the deadline then return that. Repeated calls to Failure - // within a single backoff interval will have no side effects. - if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { - return until, false +// ClearBackoff stops the backoff timer for this destination if it is running +// and removes the timer from the backoffTimers map. +func (s *ServerStatistics) ClearBackoff() { + // If the timer is still running then stop it so its memory is cleaned up sooner.
+ s.statistics.backoffMutex.Lock() + defer s.statistics.backoffMutex.Unlock() + if timer, ok := s.statistics.backoffTimers[s.serverName]; ok { + timer.Stop() } + delete(s.statistics.backoffTimers, s.serverName) - // We're either backing off and have passed the deadline, or - // we aren't backing off, so work out what the next interval - // will be. - count := s.backoffCount.Load() - until := time.Now().Add(s.duration(count)) - s.backoffUntil.Store(until) - return until, false + s.backoffStarted.Store(false) +} + +// backoffFinished will clear the previous backoff and notify the destination queue. +func (s *ServerStatistics) backoffFinished() { + s.ClearBackoff() + + // Notify the destinationQueue if one is currently running. + s.notifierMutex.Lock() + defer s.notifierMutex.Unlock() + if s.backoffNotifier != nil { + s.backoffNotifier() + } } // BackoffInfo returns information about the current or previous backoff. @@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// RemoveBlacklist removes the blacklisted status from the server. +func (s *ServerStatistics) RemoveBlacklist() { + s.cancel() + s.backoffCount.Store(0) +} + // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 225350b6d..6aa997f44 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -7,9 +7,7 @@ import ( ) func TestBackoff(t *testing.T) { - stats := Statistics{ - FailuresUntilBlacklist: 7, - } + stats := NewStatistics(nil, 7) server := ServerStatistics{ statistics: &stats, serverName: "test.com", @@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) { // Get the duration. 
_, blacklist := server.BackoffInfo() - duration := time.Until(until).Round(time.Second) + duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that // there's a backoff in progress and return the same result. @@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) { // Check if the duration is what we expect. t.Logf("Backoff %d is for %s", i, duration) - if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { - t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + roundingAllowance := 0.01 + minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance) + maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance) + var inJitterRange bool + if duration >= minDuration && duration <= maxDuration { + inJitterRange = true + } else { + inJitterRange = false + } + if !blacklist && !inJitterRange { + t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration) } } } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 9e40f311c..6afb313a8 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -52,6 +52,10 @@ type Receipt struct { nid int64 } +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + func (r *Receipt) String() string { return fmt.Sprintf("%d", r.nid) } diff --git a/go.mod b/go.mod index 911d36c1c..2248e73c6 100644 --- a/go.mod +++ b/go.mod @@ -50,6 +50,7 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 + gotest.tools/v3 v3.4.0 nhooyr.io/websocket v1.8.7 ) @@ -127,7 +128,6 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - 
gotest.tools/v3 v3.4.0 // indirect ) go 1.18 From f3dae0e749ca35b1527fbfcb0371e89d0e9833ab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 11:40:38 +0100 Subject: [PATCH 07/22] Bump nokogiri from 1.13.6 to 1.13.9 in /docs (#2809) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [nokogiri](https://github.com/sparklemotion/nokogiri) from 1.13.6 to 1.13.9.
Release notes

Sourced from nokogiri's releases.

1.13.9 / 2022-10-18

Security

Dependencies

  • [CRuby] Vendored libxml2 is updated to v2.10.3 from v2.9.14.
  • [CRuby] Vendored libxslt is updated to v1.1.37 from v1.1.35.
  • [CRuby] Vendored zlib is updated from 1.2.12 to 1.2.13. (See LICENSE-DEPENDENCIES.md for details on which packages redistribute this library.)

Fixed

  • [CRuby] Nokogiri::XML::Namespace objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2658] (Thanks, @​eightbitraptor and @​peterzhu2118!)
  • [CRuby] Document#remove_namespaces! now defers freeing the underlying xmlNs struct until the Document is GCed. Previously, maintaining a reference to a Namespace object that was removed in this way could lead to a segfault. [#2658]

sha256 checksums:

9b69829561d30c4461ea803baeaf3460e8b145cff7a26ce397119577a4083a02
nokogiri-1.13.9-aarch64-linux.gem
e76ebb4b7b2e02c72b2d1541289f8b0679fb5984867cf199d89b8ef485764956
nokogiri-1.13.9-arm64-darwin.gem
15bae7d08bddeaa898d8e3f558723300137c26a2dc2632a1f89c8574c4467165
nokogiri-1.13.9-java.gem
f6a1dbc7229184357f3129503530af73cc59ceba4932c700a458a561edbe04b9
nokogiri-1.13.9-x64-mingw-ucrt.gem
36d935d799baa4dc488024f71881ff0bc8b172cecdfc54781169c40ec02cbdb3
nokogiri-1.13.9-x64-mingw32.gem
ebaf82aa9a11b8fafb67873d19ee48efb565040f04c898cdce8ca0cd53ff1a12
nokogiri-1.13.9-x86-linux.gem
11789a2a11b28bc028ee111f23311461104d8c4468d5b901ab7536b282504154
nokogiri-1.13.9-x86-mingw32.gem
01830e1646803ff91c0fe94bc768ff40082c6de8cfa563dafd01b3f7d5f9d795
nokogiri-1.13.9-x86_64-darwin.gem
8e93b8adec22958013799c8690d81c2cdf8a90b6f6e8150ab22e11895844d781
nokogiri-1.13.9-x86_64-linux.gem
96f37c1baf0234d3ae54c2c89aef7220d4a8a1b03d2675ff7723565b0a095531
nokogiri-1.13.9.gem

1.13.8 / 2022-07-23

Deprecated

  • XML::Reader#attribute_nodes is deprecated due to incompatibility between libxml2's xmlReader memory semantics and Ruby's garbage collector. Although this method continues to exist for backwards compatibility, it is unsafe to call and may segfault. This method will be removed in a future version of Nokogiri, and callers should use #attribute_hash instead. [#2598]

Improvements

  • XML::Reader#attribute_hash is a new method to safely retrieve the attributes of a node from XML::Reader. [#2598, #2599]

Fixed

... (truncated)

Changelog

Sourced from nokogiri's changelog.

1.13.9 / 2022-10-18

Security

Dependencies

  • [CRuby] Vendored libxml2 is updated to v2.10.3 from v2.9.14.
  • [CRuby] Vendored libxslt is updated to v1.1.37 from v1.1.35.
  • [CRuby] Vendored zlib is updated from 1.2.12 to 1.2.13. (See LICENSE-DEPENDENCIES.md for details on which packages redistribute this library.)

Fixed

  • [CRuby] Nokogiri::XML::Namespace objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2658] (Thanks, @​eightbitraptor and @​peterzhu2118!)
  • [CRuby] Document#remove_namespaces! now defers freeing the underlying xmlNs struct until the Document is GCed. Previously, maintaining a reference to a Namespace object that was removed in this way could lead to a segfault. [#2658]

1.13.8 / 2022-07-23

Deprecated

  • XML::Reader#attribute_nodes is deprecated due to incompatibility between libxml2's xmlReader memory semantics and Ruby's garbage collector. Although this method continues to exist for backwards compatibility, it is unsafe to call and may segfault. This method will be removed in a future version of Nokogiri, and callers should use #attribute_hash instead. [#2598]

Improvements

  • XML::Reader#attribute_hash is a new method to safely retrieve the attributes of a node from XML::Reader. [#2598, #2599]

Fixed

  • [CRuby] Calling XML::Reader#attributes is now safe to call. In Nokogiri <= 1.13.7 this method may segfault. [#2598, #2599]

1.13.7 / 2022-07-12

Fixed

XML::Node objects, when compacted, update their internal struct's reference to the Ruby object wrapper. Previously, with GC compaction enabled, a segmentation fault was possible after compaction was triggered. [#2578] (Thanks, @​eightbitraptor!)

Commits
  • 897759c version bump to v1.13.9
  • aeb1ac3 doc: update CHANGELOG
  • c663e49 Merge pull request #2671 from sparklemotion/flavorjones-update-zlib-1.2.13_v1...
  • 212e07d ext: hack to cross-compile zlib v1.2.13 on darwin
  • 76dbc8c dep: update zlib to v1.2.13
  • 24e3a9c doc: update CHANGELOG
  • 4db3b4d Merge pull request #2668 from sparklemotion/flavorjones-namespace-scopes-comp...
  • 73d73d6 fix: Document#remove_namespaces! use-after-free bug
  • 5f58b34 fix: namespace nodes behave properly when compacted
  • b08a858 test: repro namespace_scopes compaction issue
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nokogiri&package-manager=bundler&previous-version=1.13.6&new-version=1.13.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
- `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language
- `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language
- `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language
- `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language

You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index bc73df728..c7ba43711 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.6-arm64-darwin) + nokogiri (1.13.9-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.6-x86_64-linux) + nokogiri (1.13.9-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) From c1463db6c9183aa67ef41e7ea85ed36dc5817d18 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Oct 2022 12:03:12 +0100 Subject: [PATCH 08/22] Fix concurrent map write in key server --- keyserver/internal/internal.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 06fc4987c..d2ea20935 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap // nolint:gocyclo func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) 
> 0 { // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) @@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) for domain, userToDeviceMap := range domainToDeviceKeys { @@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( // we can't safely return keys from the db when all devices are requested as we don't // know if one has just been added. if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs) + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) if err == nil { continue } @@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. respMu.Lock() - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil) + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) respMu.Unlock() if err != nil { logrus.WithFields(logrus.Fields{ @@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // inspecting the failures map though so they can know it's a cached response. 
for userID, dkeys := range devKeys { // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) } // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache @@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. @@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( if len(deviceIDs) == 0 && len(keys) == 0 { return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) } + respMu.Lock() if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } + respMu.Unlock() for _, key := range keys { if len(key.KeyJSON) == 0 { @@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` }{key.DisplayName}) + respMu.Lock() res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() } return nil } From 8cbe14bd6d985ceb2f7c098548a3fbeedfce2d55 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Oct 2022 12:27:34 +0100 Subject: [PATCH 09/22] Fix lock contention --- keyserver/internal/internal.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index d2ea20935..89621aa87 100644 --- 
a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -542,9 +542,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this // user so the fact that we're populating all devices here isn't a problem so long as we have devices. - respMu.Lock() err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) - respMu.Unlock() if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -568,6 +566,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( res.Failures[serverName] = map[string]interface{}{ "message": err.Error(), } + respMu.Unlock() // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server // is down, better to return something than nothing at all. Clients can know about the failure by @@ -578,11 +577,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } // Sytest expects no failures, if we still could retrieve keys, e.g. 
from local cache + respMu.Lock() if len(res.DeviceKeys) > 0 { delete(res.Failures, serverName) } respMu.Unlock() - } func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( From e79bfd8fd55781783482cb45ae6d4e78062bb8ac Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 19 Oct 2022 14:05:39 +0200 Subject: [PATCH 10/22] Get state deltas without filters (#2810) This makes the following changes: - get state deltas without the user supplied filter, so we can actually "calculate" state transitions - closes `stmt` when using SQLite - Adds presence for users who newly joined a room, even if the syncing user already knows about the presence status (should fix https://github.com/matrix-org/complement/pull/516) --- .../postgres/output_room_events_table.go | 77 ++++++++++++------- syncapi/storage/shared/storage_sync.go | 30 ++++++-- .../sqlite3/output_room_events_table.go | 45 +++++++++-- syncapi/streams/stream_pdu.go | 16 ++-- syncapi/streams/stream_presence.go | 3 +- syncapi/sync/request.go | 19 ++--- syncapi/types/provider.go | 5 +- sytest-blacklist | 6 -- sytest-whitelist | 8 +- 9 files changed, 144 insertions(+), 65 deletions(-) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index b562e6804..0ecbdf4d2 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -28,8 +28,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -133,7 +134,7 @@ const updateEventJSONSQL = "" + "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). 
-const selectStateInRangeSQL = "" + +const selectStateInRangeFilteredSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + @@ -146,6 +147,15 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $9" +// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). +const selectStateInRangeSQL = "" + + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" + + " FROM syncapi_output_room_events" + + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + + " AND room_id = ANY($3)" + + " ORDER BY id ASC" + + " LIMIT $4" + const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -171,20 +181,21 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectEventsWitFilterStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - updateEventJSONStmt *sql.Stmt - deleteEventsForRoomStmt *sql.Stmt - selectContextEventStmt *sql.Stmt - selectContextBeforeEventStmt *sql.Stmt - selectContextAfterEventStmt *sql.Stmt - selectSearchStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + 
selectStateInRangeFilteredStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt + updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt + selectContextEventStmt *sql.Stmt + selectContextBeforeEventStmt *sql.Stmt + selectContextAfterEventStmt *sql.Stmt + selectSearchStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -214,6 +225,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, {&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, + {&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -240,17 +252,28 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - senders, notSenders := getSendersStateFilterFilter(stateFilter) - rows, err := stmt.QueryContext( - ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(senders), - pq.StringArray(notSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), - stateFilter.ContainsURL, - stateFilter.Limit, - ) + var rows *sql.Rows + var err error + if stateFilter != nil { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeFilteredStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), + 
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), + stateFilter.ContainsURL, + stateFilter.Limit, + ) + } else { + stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) + rows, err = stmt.QueryContext( + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), + r.High()-r.Low(), + ) + } + if err != nil { return nil, nil, err } diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cb61c1c26..1f66ccc0e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -5,10 +5,11 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type DatabaseTransaction struct { @@ -277,6 +278,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. +// nolint:gocyclo func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, @@ -311,7 +313,7 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, nil, allRoomIDs) if err != nil { if err == sql.ErrNoRows { return nil, nil, nil @@ -326,6 +328,22 @@ func (d *DatabaseTransaction) GetStateDeltas( return nil, nil, err } + // get all the state events ever (i.e. 
for all available rooms) between these two positions + stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) @@ -371,6 +389,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // If our membership is now join but the previous membership wasn't // then this is a "join transition", so we'll insert this room. if prevMembership != membership { + newlyJoinedRooms[roomID] = true // Get the full room state, as we'll send that down for a newly // joined room instead of a delta. var s []types.StreamEvent @@ -383,8 +402,7 @@ func (d *DatabaseTransaction) GetStateDeltas( // Add the information for this room into the state so that // it will get added with all of the rest of the joined rooms. 
- state[roomID] = s - newlyJoinedRooms[roomID] = true + stateFiltered[roomID] = s } // We won't add joined rooms into the delta at this point as they @@ -395,7 +413,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), RoomID: roomID, }) break @@ -407,7 +425,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d6a674b9c..77c692ff0 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -29,8 +29,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -189,21 +190,36 @@ func (s *outputRoomEventsStatements) SelectStateInRange( for _, roomID := range roomIDs { inputParams = append(inputParams, roomID) } - stmt, params, err := prepareWithFilters( - s.db, txn, stmtSQL, inputParams, - stateFilter.Senders, stateFilter.NotSenders, - stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + var ( + stmt *sql.Stmt + params []any + err error ) + if stateFilter != nil { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, 
inputParams, + stateFilter.Senders, stateFilter.NotSenders, + stateFilter.Types, stateFilter.NotTypes, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + ) + } else { + stmt, params, err = prepareWithFilters( + s.db, txn, stmtSQL, inputParams, + nil, nil, + nil, nil, + nil, nil, int(r.High()-r.Low()), FilterOrderAsc, + ) + } if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectStateInRange: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, nil, err } - defer rows.Close() // nolint: errcheck + defer internal.CloseAndLogIfError(ctx, rows, "selectStateInRange: rows.close() failed") // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -269,6 +285,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ) (id int64, err error) { var nullableID sql.NullInt64 stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "SelectMaxEventID: stmt.close() failed") err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 @@ -323,6 +340,7 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + defer internal.CloseAndLogIfError(ctx, insertStmt, "InsertEvent: stmt.close() failed") _, err = insertStmt.ExecContext( ctx, streamPos, @@ -367,6 +385,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) 
if err != nil { @@ -415,6 +434,8 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -456,6 +477,8 @@ func (s *outputRoomEventsStatements) SelectEvents( if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err @@ -558,6 +581,10 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextBeforeEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -596,6 +623,10 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( filter.Types, filter.NotTypes, nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectContextAfterEvent: stmt.close() failed") rows, err := stmt.QueryContext(ctx, params...) 
if err != nil { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 613ac434f..9ec2b61cd 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -194,7 +194,7 @@ func (p *PDUStreamProvider) IncrementalSync( } } var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") if err == context.DeadlineExceeded || err == context.Canceled || err == sql.ErrTxDone { return newPos @@ -225,7 +225,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( delta types.StateDelta, eventFilter *gomatrixserverlib.RoomEventFilter, stateFilter *gomatrixserverlib.StateFilter, - res *types.Response, + req *types.SyncRequest, ) (types.StreamPosition, error) { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { // make sure we don't leak recent events after the leave event. @@ -290,8 +290,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hasMembershipChange := false for _, recentEvent := range recentStreamEvents { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { + if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join { + req.MembershipChanges[*recentEvent.StateKey()] = struct{}{} + } hasMembershipChange = true - break } } @@ -318,9 +320,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. 
- jr.Timeline.Limited = limited && len(events) == len(recentEvents) + jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[delta.RoomID] = jr + req.Response.Rooms.Join[delta.RoomID] = jr case gomatrixserverlib.Peek: jr := types.NewJoinResponse() @@ -329,7 +331,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Peek[delta.RoomID] = jr + req.Response.Rooms.Peek[delta.RoomID] = jr case gomatrixserverlib.Leave: fallthrough // transitions to leave are the same as ban @@ -342,7 +344,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // didn't "remove" events, return that the response is limited. 
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Leave[delta.RoomID] = lr + req.Response.Rooms.Leave[delta.RoomID] = lr } return latestPosition, nil diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 8b87af452..030b7c5d5 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -121,7 +121,8 @@ func (p *PresenceStreamProvider) IncrementalSync( prevPresence := pres.(*types.PresenceInternal) currentlyActive := prevPresence.CurrentlyActive() skip := prevPresence.Equals(presence) && currentlyActive && req.Device.UserID != presence.UserID - if skip { + _, membershipChange := req.MembershipChanges[presence.UserID] + if skip && !membershipChange { req.Log.Tracef("Skipping presence, no change (%s)", presence.UserID) continue } diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 268ed70c6..620dfdcdb 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -91,15 +91,16 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat }) return &types.SyncRequest{ - Context: req.Context(), // - Log: logger, // - Device: &device, // - Response: types.NewResponse(), // Populated by all streams - Filter: filter, // - Since: since, // - Timeout: timeout, // - Rooms: make(map[string]string), // Populated by the PDU stream - WantFullState: wantFullState, // + Context: req.Context(), // + Log: logger, // + Device: &device, // + Response: types.NewResponse(), // Populated by all streams + Filter: filter, // + Since: since, // + Timeout: timeout, // + Rooms: make(map[string]string), // Populated by the PDU stream + WantFullState: wantFullState, // + MembershipChanges: make(map[string]struct{}), // Populated by the PDU stream }, nil } diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index 378cafe99..9a533002b 
100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -4,9 +4,10 @@ import ( "context" "time" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" + + userapi "github.com/matrix-org/dendrite/userapi/api" ) type SyncRequest struct { @@ -22,6 +23,8 @@ type SyncRequest struct { // Updated by the PDU stream. Rooms map[string]string // Updated by the PDU stream. + MembershipChanges map[string]struct{} + // Updated by the PDU stream. IgnoredUsers IgnoredUsers } diff --git a/sytest-blacklist b/sytest-blacklist index 634c07cf3..fe48fb791 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -39,12 +39,6 @@ Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server -# Flakey, need additional investigation - -Messages that notify from another user increment notification_count -Messages that highlight from another user increment unread highlight count -Notifications can be viewed with GET /notifications - # More flakey Guest users can join guest_access rooms diff --git a/sytest-whitelist b/sytest-whitelist index 93d447d28..1387838f7 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -746,4 +746,10 @@ Existing members see new member's presence Inbound federation can return missing events for joined visibility outliers whose auth_events are in a different room are correctly rejected Messages that notify from another user increment notification_count -Messages that highlight from another user increment unread highlight count \ No newline at end of file +Messages that highlight from another user increment unread highlight count +Newly joined room has correct timeline in incremental sync +When user joins a room the state is included in the next sync +When user joins a room the state is included in a gapped sync +Messages that notify from another user increment notification_count 
+Messages that highlight from another user increment unread highlight count +Notifications can be viewed with GET /notifications \ No newline at end of file From 6a93858125a2ece1c2cc557e11d34e51a67ada45 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 20 Oct 2022 10:45:59 +0200 Subject: [PATCH 11/22] Fix race condition --- federationapi/queue/destinationqueue.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 00e02b2d9..768ed1f2b 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -21,16 +21,17 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "go.uber.org/atomic" ) const ( @@ -541,6 +542,8 @@ func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) // the pending events and EDUs, and wipe our transaction ID. 
oq.statistics.Success() oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + for i := range oq.pendingPDUs[:pduCount] { oq.pendingPDUs[i] = nil } @@ -549,7 +552,6 @@ func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) } oq.pendingPDUs = oq.pendingPDUs[pduCount:] oq.pendingEDUs = oq.pendingEDUs[eduCount:] - oq.pendingMutex.Unlock() if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 { select { From 539c61b3db4a76729e90a52823be89c32cbeb5ec Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 20 Oct 2022 12:34:53 +0200 Subject: [PATCH 12/22] Remove test from blacklist --- sytest-blacklist | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sytest-blacklist b/sytest-blacklist index fe48fb791..14edf398a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -22,10 +22,6 @@ Forgotten room messages cannot be paginated Local device key changes get to remote servers with correct prev_id -# Flakey - -Local device key changes appear in /keys/changes - # we don't support groups Remove group category From b58c9bb094f3a069a4f40bbd6cc4a0ac205afcb6 Mon Sep 17 00:00:00 2001 From: devonh Date: Thu, 20 Oct 2022 15:37:35 +0000 Subject: [PATCH 13/22] Fix flakey queue test (#2818) Ensure both events are added to the database, even if the destination is already blacklisted. --- federationapi/queue/queue_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 6da9e6b30..40419b91f 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -1004,9 +1004,12 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) assert.NoError(t, err) + // NOTE : The server can be blacklisted before this, so manually inject the event + // into the database. 
edu := mustCreateEDU(t) - errEDU := queues.SendEDU(edu, "localhost", []gomatrixserverlib.ServerName{destination}) - assert.NoError(t, errEDU) + ephemeralJSON, _ := json.Marshal(edu) + nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) + db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil) checkBlacklisted := func(log poll.LogT) poll.Result { if fc.txCount.Load() == failuresUntilBlacklist { From 90414912012b274f49894b2819f5e6e393928da9 Mon Sep 17 00:00:00 2001 From: devonh Date: Thu, 20 Oct 2022 15:54:18 +0000 Subject: [PATCH 14/22] Mutex protect query keys response (#2812) --- keyserver/internal/internal.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 89621aa87..49ef03054 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -472,7 +472,9 @@ func (a *KeyInternalAPI) queryRemoteKeys( close(resultCh) }() - for result := range resultCh { + processResult := func(result *gomatrixserverlib.RespQueryKeys) { + respMu.Lock() + defer respMu.Unlock() for userID, nest := range result.DeviceKeys { res.DeviceKeys[userID] = make(map[string]json.RawMessage) for deviceID, deviceKey := range nest { @@ -495,6 +497,10 @@ func (a *KeyInternalAPI) queryRemoteKeys( // TODO: do we want to persist these somewhere now // that we have fetched them? } + + for result := range resultCh { + processResult(result) + } } func (a *KeyInternalAPI) queryRemoteKeysOnServer( From 73e02463cf6e267fdba950d0d231f98f95bc7994 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 21 Oct 2022 09:19:52 +0100 Subject: [PATCH 15/22] Allow `m.read.private` to clear notifications (#2811) Otherwise if a user switches to private read receipts, they may not be able to clear notification counts. 
--- userapi/consumers/clientapi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index c220d35cb..79f1bf06f 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -81,7 +81,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats readPos := msg.Header.Get(jetstream.EventID) evType := msg.Header.Get("type") - if readPos == "" || evType != "m.read" { + if readPos == "" || (evType != "m.read" && evType != "m.read.private") { return true } From 40cfb9a4ea23f1c9214553255feb296c2578b213 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 10:26:22 +0200 Subject: [PATCH 16/22] Fix `invite -> leave -> join` dance when accepting invites (#2817) As mentioned in https://github.com/matrix-org/dendrite/issues/2361#issuecomment-1139394565 and observed by ourselves, this should fix the odd `invite -> leave -> join` dance when accepting invites. --- syncapi/consumers/roomserver.go | 7 +++++++ syncapi/streams/stream_invite.go | 33 ++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index cfbb05327..f767615c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -428,6 +428,13 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return } + // Only notify clients about retired invite events, if the user didn't accept the invite. + // The PDU stream will also receive an event about accepting the invitation, so there should + // be a "smooth" transition from invite -> join, and not invite -> leave -> join + if msg.Membership == gomatrixserverlib.Join { + return + } + // Notify any active sync requests that the invite has been retired. 
s.inviteStream.Advance(pduPos) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 7875ffa35..700f25c10 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -74,21 +74,26 @@ func (p *InviteStreamProvider) IncrementalSync( return to } for roomID := range retiredInvites { - if _, ok := req.Response.Rooms.Join[roomID]; !ok { - lr := types.NewLeaveResponse() - h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) - lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ - // fake event ID which muxes in the to position - EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), - RoomID: roomID, - Sender: req.Device.UserID, - StateKey: &req.Device.UserID, - Type: "m.room.member", - Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), - }) - req.Response.Rooms.Leave[roomID] = lr + if _, ok := req.Response.Rooms.Invite[roomID]; ok { + continue } + if _, ok := req.Response.Rooms.Join[roomID]; ok { + continue + } + lr := types.NewLeaveResponse() + h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) + lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ + // fake event ID which muxes in the to position + EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), + OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + RoomID: roomID, + Sender: req.Device.UserID, + StateKey: &req.Device.UserID, + Type: "m.room.member", + Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), + }) + req.Response.Rooms.Leave[roomID] = lr + } return maxID From e57b30172227c4a0b7de15ba635b20921dedda5e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 10:48:25 +0200 Subject: [PATCH 17/22] Set 
`display_name` and/or `avatar_url` for server notices (#2820) This should fix #2815 by making sure we actually set the `display_name` and/or `avatar_url` and create the needed membership event. To avoid creating a new membership event when starting Dendrite, `SetAvatarURL` and `SetDisplayName` now return a `Changed` value, which also makes the regular endpoints idempotent. --- clientapi/routing/profile.go | 127 ++++++++-------------- clientapi/routing/routing.go | 2 +- clientapi/routing/server_notices.go | 31 +++++- userapi/api/api.go | 12 +- userapi/api/api_trace.go | 2 +- userapi/internal/api.go | 14 ++- userapi/inthttp/client.go | 2 +- userapi/storage/interface.go | 4 +- userapi/storage/postgres/profile_table.go | 36 ++++-- userapi/storage/shared/storage.go | 16 ++- userapi/storage/sqlite3/profile_table.go | 40 +++++-- userapi/storage/storage_test.go | 24 ++-- userapi/storage/tables/interface.go | 4 +- userapi/userapi_test.go | 9 +- 14 files changed, 191 insertions(+), 132 deletions(-) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 0685c7352..c9647eb1b 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -19,6 +19,8 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -27,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" @@ -126,20 +127,6 @@ func SetAvatarURL( } } - res := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, res) - if err != nil { - 
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ Localpart: localpart, @@ -148,41 +135,17 @@ func SetAvatarURL( util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() } - - var roomsRes api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &roomsRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: oldProfile.DisplayName, - AvatarURL: r.AvatarURL, - } - - events, err := buildMembershipEvents( - req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, - ) - switch e := err.(type) { - case nil: - case gomatrixserverlib.BadJSONError: + // No need to build new membership events, since nothing changed + if !setRes.Changed { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + Code: http.StatusOK, + JSON: struct{}{}, } - default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime) + if err != nil { + return 
response } return util.JSONResponse{ @@ -255,47 +218,51 @@ func SetDisplayName( } } - pRes := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, pRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: pRes.DisplayName, - AvatarURL: pRes.AvatarURL, - } - + profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, DisplayName: r.DisplayName, - }, &struct{}{}) + }, profileRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() } + // No need to build new membership events, since nothing changed + if !profileRes.Changed { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime) + if err != nil { + return response + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func updateProfile( + ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, + profile *authtypes.Profile, + userID string, cfg *config.ClientAPI, evTime time.Time, +) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: r.DisplayName, - AvatarURL: 
oldProfile.AvatarURL, + util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError(), err } events, err := buildMembershipEvents( - req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -303,21 +270,17 @@ func SetDisplayName( return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(e.Error()), - } + }, e default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") + return jsonerror.InternalServerError(), e } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError(), err } + return util.JSONResponse{}, nil } // getProfile gets the full profile of a user by querying the database or a diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index ec5ca899e..4ca8e59c5 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -178,7 +178,7 @@ func Setup( // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) + serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) if err != nil { 
logrus.WithError(err).Fatal("unable to get account for sending sending server notices") } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 7729eddd8..a6a78061d 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -277,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) { // It returns an userapi.Device, which is used for building the event func getSenderDevice( ctx context.Context, + rsAPI api.ClientRoomserverAPI, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, ) (*userapi.Device, error) { @@ -291,16 +292,32 @@ func getSenderDevice( return nil, err } - // set the avatarurl for the user - res := &userapi.PerformSetAvatarURLResponse{} + // Set the avatarurl for the user + avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, - }, res); err != nil { + }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err } + profile := avatarRes.Profile + + // Set the displayname for the user + displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} + if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DisplayName: cfg.Matrix.ServerNotices.DisplayName, + }, displayNameRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") + return nil, err + } + + if displayNameRes.Changed { + profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName + } + // Check if we got existing devices deviceRes := &userapi.QueryDevicesResponse{} err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ @@ -310,7 +327,15 @@ func getSenderDevice( return nil, err } + // We've got an existing account, return the first device of it if len(deviceRes.Devices) > 0 { + // 
If there were changes to the profile, create a new membership event + if displayNameRes.Changed || avatarRes.Changed { + _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) + if err != nil { + return nil, err + } + } return &deviceRes.Devices[0], nil } diff --git a/userapi/api/api.go b/userapi/api/api.go index 66ee9c7c8..eef29144a 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -96,7 +96,7 @@ type ClientUserAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error - SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error + SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error @@ -579,7 +579,10 @@ type Notification struct { type PerformSetAvatarURLRequest struct { Localpart, AvatarURL string } -type PerformSetAvatarURLResponse struct{} +type PerformSetAvatarURLResponse struct { + Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` +} type QueryNumericLocalpartResponse struct { ID int64 @@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct { Localpart, DisplayName string } +type PerformUpdateDisplayNameResponse struct { + Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` +} + type 
QueryLocalpartForThreePIDRequest struct { ThreePID, Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 7e2f69615..90834f7e3 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req return err } -func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error { +func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error { err := t.Impl.SetDisplayName(ctx, req, res) util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) return err diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 2f7795dfe..63044eedb 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + res.Profile = profile + res.Changed = changed + return err } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { @@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } } -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ 
*struct{}) error { - return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) +func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + res.Profile = profile + res.Changed = changed + return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index a375d6caa..aa5d46d9f 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword( func (h *httpUserInternalAPI) SetDisplayName( ctx context.Context, request *api.PerformUpdateDisplayNameRequest, - response *struct{}, + response *api.PerformUpdateDisplayNameResponse, ) error { return httputil.CallInternalRPCAPI( "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 02efe7afe..fb12b53af 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,8 +29,8 @@ import ( type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error - SetDisplayName(ctx context.Context, localpart string, displayName string) error + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) } type Account interface { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go 
index f686127be..2753b23d9 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -44,10 +44,18 @@ const selectProfileByLocalpartSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" + + " SET avatar_url = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" + + " SET display_name = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" @@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) - return +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } + var changed bool + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + return profile, changed, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) - return +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } + var 
changed bool + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + return profile, changed, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 4e28f7b5a..f8b6ad311 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) +) (profile *authtypes.Profile, changed bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + return err }) + return } // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) +) (profile *authtypes.Profile, changed bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + return err }) + return } // SetPassword sets the account password to the given hash. 
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 267daf044..b6130a1e3 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" @@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + return old, false, err + } + if old.AvatarURL == avatarURL { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - _, err = stmt.ExecContext(ctx, avatarURL, localpart) - return + err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + return profile, true, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { +) (*authtypes.Profile, bool, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + 
return old, false, err + } + if old.DisplayName == displayName { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - _, err = stmt.ExecContext(ctx, displayName, localpart) - return + err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + return profile, true, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 8e5b32b6a..354f085fc 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -382,15 +382,23 @@ func Test_Profile(t *testing.T) { // set avatar & displayname wantProfile.DisplayName = "Alice" - wantProfile.AvatarURL = "mxc://aliceAvatar" - err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") - assert.NoError(t, err, "unable to set displayname") - err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") - assert.NoError(t, err, "unable to set avatar url") - // verify profile - gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) - assert.NoError(t, err, "unable to get profile by localpart") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") assert.Equal(t, wantProfile, gotProfile) + assert.NoError(t, err, "unable to set displayname") + assert.True(t, changed) + + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + assert.Equal(t, wantProfile, gotProfile) + assert.True(t, changed) + + // Setting the same avatar again doesn't change anything + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + assert.Equal(t, wantProfile, gotProfile) + assert.False(t, changed) // search profiles searchRes, err := db.SearchProfiles(ctx, "Alice", 2) diff --git 
a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index cc4287997..1b239e442 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -84,8 +84,8 @@ type OpenIDTable interface { type ProfileTable interface { InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 4417f4dc0..aaa93f45b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -23,13 +23,14 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" @@ -83,10 +84,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to make account: %s", err) } - if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { t.Fatalf("failed to set 
avatar url: %s", err) } - if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) } From e98d75fd63103243c5af2a63f2f547e4300adc4d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 21 Oct 2022 10:15:08 +0100 Subject: [PATCH 18/22] Verify `room_id`, `type`, `sender` and `state_key` field lengths using bytes rather than codepoints (update to matrix-org/gomatrixserverlib@7c772f1, reverts bbb3ade4a2b49cfdaf7ec86ddf079ff7d48e0cf3) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2248e73c6..7f9bb3897 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a + github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index a141fc9b4..5cce7e0d8 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a 
h1:bQKHk3AWlgm7XhzPhuU3Iw3pUptW5l1DR/1y0o7zCKQ= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221018085104-a72a83f0e19a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 9e4c3171da4e2d6d7b95731e702891513d081b49 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 12:50:51 +0200 Subject: [PATCH 19/22] Optimize inserting pending PDUs/EDUs (#2821) This optimizes the association of PDUs/EDUs to their destination by inserting all destinations in one transaction. --- federationapi/queue/destinationqueue.go | 88 ++++++-------------- federationapi/queue/queue.go | 35 +++++++- federationapi/queue/queue_test.go | 60 +++++++------ federationapi/storage/interface.go | 7 +- federationapi/storage/shared/storage_edus.go | 25 +++--- federationapi/storage/shared/storage_pdus.go | 24 +++--- federationapi/storage/storage_test.go | 7 +- 7 files changed, 127 insertions(+), 119 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 768ed1f2b..1b7670e9a 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -76,40 +76,22 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re return } - // Create a database entry that associates the given PDU NID with - // this destination queue. 
We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociatePDUWithDestination( - oq.process.Context(), - "", // TODO: remove this, as we don't need to persist the transaction ID - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - ); err != nil { - logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) - return - } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingPDUs) < maxPDUsInMemory { - oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - - if !oq.backingOff.Load() { - oq.wakeQueueAndNotify() - } + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingPDUs) < maxPDUsInMemory { + oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ + pdu: event, + receipt: receipt, + }) } else { oq.overflowed.Store(true) } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } } // sendEDU adds the EDU event to the pending queue for the destination. @@ -120,41 +102,23 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. 
- if err := oq.db.AssociateEDUWithDestination( - oq.process.Context(), - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - event.Type, - nil, // this will use the default expireEDUTypes map - ); err != nil { - logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) - return - } - // Check if the destination is blacklisted. If it isn't then wake - // up the queue. - if !oq.statistics.Blacklisted() { - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingEDUs) < maxEDUsInMemory { - oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() - if !oq.backingOff.Load() { - oq.wakeQueueAndNotify() - } + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingEDUs) < maxEDUsInMemory { + oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ + edu: event, + receipt: receipt, + }) } else { oq.overflowed.Store(true) } + oq.pendingMutex.Unlock() + + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } } // handleBackoffNotifier is registered as the backoff notification diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 68f789e37..328334379 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -24,6 +24,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -247,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEvent(ev, nid) + } else { + 
delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destinations queue. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociatePDUWithDestinations( + oqs.process.Context(), + destmap, + nid, // NIDs from federationapi_queue_json table + ); err != nil { + logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid) + return err + } + return nil } @@ -321,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU( } for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil { + if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { queue.sendEDU(e, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destination queue. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociateEDUWithDestinations( + oqs.process.Context(), + destmap, // the destination server name + nid, // NIDs from federationapi_queue_json table + e.Type, + nil, // this will use the default expireEDUTypes map + ); err != nil { + logrus.WithError(err).Errorf("failed to associate EDU with destinations") + return err + } + return nil } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 40419b91f..a1b280103 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -25,6 +25,10 @@ import ( "go.uber.org/atomic" "gotest.tools/v3/poll" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" @@ -34,9 +38,6 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - 
"github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { @@ -158,30 +159,36 @@ func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixse return edus, nil } -func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { +func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingPDUs[receipt]; ok { - if _, ok := d.associatedPDUs[serverName]; !ok { - d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[destination][receipt] = struct{}{} } - d.associatedPDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("PDU doesn't exist") } } -func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { +func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingEDUs[receipt]; ok { - if _, ok := d.associatedEDUs[serverName]; !ok { - d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + 
d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[destination][receipt] = struct{}{} } - d.associatedEDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("EDU doesn't exist") @@ -821,15 +828,15 @@ func TestSendPDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxPDUsPerTransaction pduMultiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } ev := mustCreatePDU(t) @@ -865,13 +872,15 @@ func TestSendEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction eduMultiplier := uint32(3) for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -907,23 +916,23 @@ func TestSendPDUAndEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction multiplier := uint32(3) - for i := 0; i < 
maxPDUsPerTransaction*int(multiplier)+1; i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -960,13 +969,12 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { dest := queues.getQueue(destination) queues.statistics.ForServer(destination).Failure() - + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") pollEnd := time.Now().Add(3 * time.Second) runningCheck := func(log poll.LogT) poll.Result { @@ -988,6 +996,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") + destinations := 
map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. @@ -1009,7 +1018,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { edu := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(edu) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil) + err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") checkBlacklisted := func(log poll.LogT) poll.Result { if fc.txCount.Load() == failuresUntilBlacklist { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b8109b432..09098cd1e 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -18,9 +18,10 @@ import ( "context" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/gomatrixserverlib" ) type Database interface { @@ -38,8 +39,8 @@ type Database interface { GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, 
expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index e0c740c11..c796d2f8f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{ // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. 
-func (d *Database) AssociateEDUWithDestination( +func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination( expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire - ); err != nil { - return fmt.Errorf("InsertQueueEDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueueEDUs.InsertQueueEDU( + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire + ) } - return nil + return err }) } diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 5a12c388a..dc37d7507 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -27,23 +27,23 @@ import ( // AssociatePDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. 
-func (d *Database) AssociatePDUWithDestination( +func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - ); err != nil { - return fmt.Errorf("InsertQueuePDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + ) } - return nil + return err }) } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 3b0268e55..6272fd2b1 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) { } ctx := context.Background() + destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) { assert.NoError(t, err) // m.read_marker gets the default expiry of 24h, so won't be 
deleted further down in this test - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes) assert.NoError(t, err) // Delete expired EDUs @@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) From 3cf42a1d64712f057fde0a5a4b3db1cf33ca432d Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 12:53:04 +0200 Subject: [PATCH 20/22] Add `syncapi_memberships` table tests (#2805) --- syncapi/storage/tables/memberships_test.go | 198 +++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 syncapi/storage/tables/memberships_test.go diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go new file mode 100644 index 000000000..0cee7f5a5 --- /dev/null +++ b/syncapi/storage/tables/memberships_test.go @@ -0,0 +1,198 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "sort" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newMembershipsTable(t *testing.T, dbType test.DBType) (tables.Memberships, *sql.DB, func()) { + t.Helper() + connStr, close := 
test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Memberships + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresMembershipsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteMembershipsTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestMembershipsTable(t *testing.T) { + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + // Create users + var userEvents []*gomatrixserverlib.HeaderedEvent + users := []string{alice.ID} + for _, x := range room.CurrentState() { + if x.StateKeyEquals(alice.ID) { + if _, err := x.Membership(); err == nil { + userEvents = append(userEvents, x) + break + } + } + } + + if len(userEvents) == 0 { + t.Fatalf("didn't find creator membership event") + } + + for i := 0; i < 10; i++ { + u := test.NewUser(t) + users = append(users, u.ID) + + ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(u.ID)) + userEvents = append(userEvents, ev) + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + table, _, close := newMembershipsTable(t, dbType) + defer close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for _, ev := range userEvents { + if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + } + + testUpsert(t, ctx, table, userEvents[0], alice, room) + testMembershipCount(t, ctx, table, room) + testHeroes(t, ctx, table, alice, room, users) + }) +} + +func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users 
[]string) { + + // Re-slice and sort the expected users + users = users[1:] + sort.Strings(users) + type testCase struct { + name string + memberships []string + wantHeroes []string + } + + testCases := []testCase{ + {name: "no memberships queried", memberships: []string{}}, + {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) + if err != nil { + t.Fatalf("unable to select heroes: %s", err) + } + if gotLen := len(got); gotLen != len(tc.wantHeroes) { + t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) + } + + if !reflect.DeepEqual(got, tc.wantHeroes) { + t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) + } + }) + } +} + +func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { + t.Run("membership counts are correct", func(t *testing.T) { + // After 10 events, we should have 6 users (5 create related [incl. 
one member event], 5 member events = 6 users) + count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10) + if err != nil { + t.Fatalf("failed to get membership count: %s", err) + } + expectedCount := 6 + if expectedCount != count { + t.Fatalf("expected member count to be %d, got %d", expectedCount, count) + } + + // After 100 events, we should have all 11 users + count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100) + if err != nil { + t.Fatalf("failed to get membership count: %s", err) + } + expectedCount = 11 + if expectedCount != count { + t.Fatalf("expected member count to be %d, got %d", expectedCount, count) + } + }) +} + +func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) { + t.Run("upserting works as expected", func(t *testing.T) { + if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + membership, pos, err := table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1) + if err != nil { + t.Fatalf("failed to select membership: %s", err) + } + expectedPos := 1 + if pos != expectedPos { + t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) + } + if membership != gomatrixserverlib.Join { + t.Fatalf("expected membership to be join, got %s", membership) + } + // Create a new event which gets upserted and should not cause issues + ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": gomatrixserverlib.Join, + }, test.WithStateKey(user.ID)) + // Insert the same event again, but with different positions, which should get updated + if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { + t.Fatalf("failed to upsert membership: %s", err) + } + + // Verify the position got updated + membership, pos, err = table.SelectMembershipForUser(ctx, 
nil, room.ID, user.ID, 10) + if err != nil { + t.Fatalf("failed to select membership: %s", err) + } + expectedPos = 2 + if pos != expectedPos { + t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) + } + if membership != gomatrixserverlib.Join { + t.Fatalf("expected membership to be join, got %s", membership) + } + + // If we can't find a membership, it should default to leave + if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil { + t.Fatalf("failed to select membership: %s", err) + } + if membership != gomatrixserverlib.Leave { + t.Fatalf("expected membership to be leave, got %s", membership) + } + }) +} From 411db6083b8257bfe96663e6bb7ce763609216fa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 21 Oct 2022 15:00:51 +0100 Subject: [PATCH 21/22] Version 0.10.4 (#2822) Changelog and version bump. --- CHANGES.md | 18 ++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index eea2c3c7c..1ed87824a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,23 @@ # Changelog +## Dendrite 0.10.4 (2022-10-21) + +### Features + +* Various tables belonging to the user API will be renamed so that they are namespaced with the `userapi_` prefix + * Note that, after upgrading to this version, you should not revert to an older version of Dendrite as the database changes **will not** be reverted automatically +* The backoff and retry behaviour in the federation API has been refactored and improved + +### Fixes + +* Private read receipt support is now advertised in the client `/versions` endpoint +* Private read receipts will now clear notification counts properly +* A bug where a false `leave` membership transition was inserted into the timeline after accepting an invite has been fixed +* Some panics caused by concurrent map writes in the key server have been fixed +* The sync API now calculates membership transitions from state deltas more 
accurately +* Transaction IDs are now scoped to endpoints, which should fix some bugs where transaction ID reuse could cause nonsensical cached responses from some endpoints +* The lengths of the `type`, `sender`, `state_key` and `room_id` fields in events are now verified by number of bytes rather than codepoints after a spec clarification, reverting a change made in Dendrite 0.9.6 + ## Dendrite 0.10.3 (2022-10-14) ### Features diff --git a/internal/version.go b/internal/version.go index c888748a8..5d739a45d 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 3 + VersionPatch = 4 VersionTag = "" // example: "rc1" ) From fb14a2d91791dd9a4fd55d7fffbe43d08b8cae81 Mon Sep 17 00:00:00 2001 From: danielaloni Date: Thu, 27 Oct 2022 16:45:42 +0300 Subject: [PATCH 22/22] =?UTF-8?q?=F0=9F=90=9B=20account=5Faccounts=20table?= =?UTF-8?q?=20renamed=20to=20userapi=5Faccounts.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- userapi/storage/postgres/deltas/2022080800000000_no_guests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/userapi/storage/postgres/deltas/2022080800000000_no_guests.go b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go index cc6126aad..9985fd822 100644 --- a/userapi/storage/postgres/deltas/2022080800000000_no_guests.go +++ b/userapi/storage/postgres/deltas/2022080800000000_no_guests.go @@ -8,7 +8,7 @@ import ( func UpNoGuests(ctx context.Context, tx *sql.Tx) error { // AddAccountType introduced a bug where each user that had was registered as a regular user, but without user_id, became a guest. - _, err := tx.ExecContext(ctx, "UPDATE account_accounts SET account_type = 1 WHERE account_type = 2;") + _, err := tx.ExecContext(ctx, "UPDATE userapi_accounts SET account_type = 1 WHERE account_type = 2;") if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) }