From b00dc9dc8378382215070e738ae0df282283db1e Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Thu, 6 Aug 2020 19:13:33 +0100 Subject: [PATCH] Add stale_device_lists table and use db before asking remote for device keys --- keyserver/internal/device_list_update.go | 1 + keyserver/internal/internal.go | 33 +++++ .../storage/postgres/stale_device_lists.go | 118 ++++++++++++++++++ keyserver/storage/postgres/storage.go | 13 +- keyserver/storage/shared/storage.go | 13 +- .../storage/sqlite3/stale_device_lists.go | 118 ++++++++++++++++++ keyserver/storage/sqlite3/storage.go | 13 +- keyserver/storage/tables/interface.go | 6 + 8 files changed, 301 insertions(+), 14 deletions(-) create mode 100644 keyserver/storage/postgres/stale_device_lists.go create mode 100644 keyserver/storage/sqlite3/stale_device_lists.go diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 279da65aa..ac6c89998 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -87,6 +87,7 @@ type DeviceListUpdaterDatabase interface { PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) } +// KeyChangeProducer is the interface for producers.KeyChange useful for testing. type KeyChangeProducer interface { ProduceKeyChanges(keys []api.DeviceMessage) error } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ff298c07c..7ce25d428 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -263,10 +263,43 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } // TODO: set device display names when they are known + // attempt to satisfy key queries from the local database first as we should get device updates pushed to us + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + if len(domainToDeviceKeys) == 0 { + return // nothing to query + } + // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } +func (a *KeyInternalAPI) remoteKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, +) map[string]map[string][]string { + fetchRemote := make(map[string]map[string][]string) + for domain, userToDeviceMap := range domainToDeviceKeys { + for userID, deviceIDs := range userToDeviceMap { + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + // NB: requesting all keys (deviceIDs==0) will not trigger this, provided some devices exist. + if err != nil || len(keys) < len(deviceIDs) { + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) + } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + continue + } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, key := range keys { + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + } + } + } + return fetchRemote +} + func (a *KeyInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ) { diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go new file mode 100644 index 000000000..63281adfb --- /dev/null +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index a1d1c0feb..de2fabfdf 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s if err != nil { return nil, err } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 68964be67..4279eae77 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -26,10 +26,11 @@ import ( ) type Database struct { - DB *sql.DB - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges + DB *sql.DB + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return nil, nil // TODO + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return nil // TODO + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) } diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go new file mode 100644 index 000000000..a989476d1 --- /dev/null +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index f9771cf16..bbfd1e793 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) { if err != nil { return nil, err } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index ac932d56d..a4d5dede2 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" ) type OneTimeKeys interface { @@ -45,3 +46,8 @@ type KeyChanges interface { // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) } + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) +}