diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 35852e8ac..3241234ac 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -189,7 +189,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, if err := decoder.Decode(&dl); err != nil { return "", fmt.Errorf("failed to decode build image output line: %w", err) } - log.Printf("%s: %s", branchOrTagName, dl.Stream) + if len(strings.TrimSpace(dl.Stream)) > 0 { + log.Printf("%s: %s", branchOrTagName, dl.Stream) + } if dl.Aux != nil { imgID, ok := dl.Aux["ID"] if ok { @@ -425,8 +427,10 @@ func cleanup(dockerClient *client.Client) { // ignore all errors, we are just cleaning up and don't want to fail just because we fail to cleanup containers, _ := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{ Filters: label(dendriteUpgradeTestLabel), + All: true, }) for _, c := range containers { + log.Printf("Removing container: %v %v\n", c.ID, c.Names) s := time.Second _ = dockerClient.ContainerStop(context.Background(), c.ID, &s) _ = dockerClient.ContainerRemove(context.Background(), c.ID, types.ContainerRemoveOptions{ diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 7593988f3..38b146d70 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -368,6 +368,7 @@ logging: - type: std level: info - type: file + # The logging level, must be one of debug, info, warn, error, fatal, panic. level: info params: path: ./logs diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/monolith-sample.conf index 0344aa96c..360eb9255 100644 --- a/docs/nginx/monolith-sample.conf +++ b/docs/nginx/monolith-sample.conf @@ -1,3 +1,7 @@ +#change IP to location of monolith server +upstream monolith{ + server 127.0.0.1:8008; +} server { listen 443 ssl; # IPv4 listen [::]:443 ssl; # IPv6 @@ -23,6 +27,6 @@ server { } location /_matrix { - proxy_pass http://monolith:8008; + proxy_pass http://monolith; } } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 5a109cc65..0eea2f0fa 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -69,7 +69,8 @@ type DeviceMessage struct { *DeviceKeys `json:"DeviceKeys,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` // A monotonically increasing number which represents device changes for this user. - StreamID int + StreamID int + DeviceChangeID int64 } // DeviceKeys represents a set of device keys for a single device @@ -224,8 +225,6 @@ type QueryKeysResponse struct { } type QueryKeyChangesRequest struct { - // The partition which had key events sent to - Partition int32 // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. @@ -236,8 +235,6 @@ type QueryKeyChangesRequest struct { type QueryKeyChangesResponse struct { // The set of users who have had their keys change. UserIDs []string - // The partition being served - useful if the partition is unknown at request time - Partition int32 // The latest offset represented in this response. Offset int64 // Set if there was a problem handling the request. diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 3e91962ed..259249217 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -59,17 +59,13 @@ func (a *KeyInternalAPI) InputDeviceListUpdate( } func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { - if req.Partition < 0 { - req.Partition = a.Producer.DefaultPartition() - } - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset, req.ToOffset) + userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), } } res.Offset = latest - res.Partition = req.Partition res.UserIDs = userIDs } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 03a221a6c..8cc50ea0d 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -40,16 +40,16 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - _, consumer, producer := jetstream.Prepare(&cfg.Matrix.JetStream) + js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Producer: producer, - DB: db, + Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), + JetStream: js, + DB: db, } ap := &internal.KeyInternalAPI{ DB: db, diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 782675c2a..fd143c6cf 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -18,52 +18,47 @@ import ( "context" "encoding/json" - "github.com/Shopify/sarama" eduapi "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { - Topic string - Producer sarama.SyncProducer - DB storage.Database -} - -// DefaultPartition returns the default partition this process is sending key changes to. -// NB: A keyserver MUST send key changes to only 1 partition or else query operations will -// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but -// then all keyservers must be queried to calculate the entire set of key changes between -// two sync tokens. -func (p *KeyChange) DefaultPartition() int32 { - return 0 + Topic string + JetStream nats.JetStreamContext + DB storage.Database } // ProduceKeyChanges creates new change events for each key func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { userToDeviceCount := make(map[string]int) for _, key := range keys { - var m sarama.ProducerMessage - + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + key.DeviceChangeID = id value, err := json.Marshal(key) if err != nil { return err } - m.Topic = string(p.Topic) - m.Key = sarama.StringEncoder(key.UserID) - m.Value = sarama.ByteEncoder(value) + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value - partition, offset, err := p.Producer.SendMessage(&m) - if err != nil { - return err - } - err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + _, err = p.JetStream.PublishMsg(m) if err != nil { return err } + userToDeviceCount[key.UserID]++ } for userID, count := range userToDeviceCount { @@ -76,7 +71,6 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { } func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error { - var m sarama.ProducerMessage output := &api.DeviceMessage{ Type: api.TypeCrossSigningUpdate, OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{ @@ -84,20 +78,25 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er }, } + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + output.DeviceChangeID = id + value, err := json.Marshal(output) if err != nil { return err } - m.Topic = string(p.Topic) - m.Key = sarama.StringEncoder(key.UserID) - m.Value = sarama.ByteEncoder(value) - - partition, offset, err := p.Producer.SendMessage(&m) - if err != nil { - return err + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, } - err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) if err != nil { return err } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 99842bc58..87feae47d 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -66,14 +66,14 @@ type Database interface { // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - // StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed - // their keys in some way. - StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). // A to offset of sarama.OffsetNewest means no upper limit. // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go new file mode 100644 index 000000000..e5bcf08d1 --- /dev/null +++ b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go @@ -0,0 +1,79 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func LoadRefactorKeyChanges(m *sqlutil.Migrations) { + m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func UpRefactorKeyChanges(tx *sql.Tx) error { + // start counting from the last max offset, else 0. We need to do a count(*) first to see if there + // even are entries in this table to know if we can query for log_offset. Without the count then + // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't + // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ + var count int + _ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count) + if count > 0 { + var maxOffset int64 + _ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) + if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { + return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) + } + } + + _, err := tx.Exec(` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRefactorKeyChanges(tx *sql.Tx) error { + _, err := tx.Exec(` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index df4b47e79..20d227c24 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -26,27 +26,25 @@ import ( var keyChangesSchema = ` -- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - log_offset BIGINT NOT NULL, + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) ); ` -// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. -// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will -// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" + - " DO UPDATE SET user_id = $3" + "INSERT INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + + " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + + " RETURNING change_id" -// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just -// take the max offset value as the latest offset. const selectKeyChangesSQL = "" + - "SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 AND log_offset <= $3 GROUP BY user_id" + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" type keyChangesStatements struct { db *sql.DB @@ -59,31 +57,32 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { db: db, } _, err := db.Exec(keyChangesSchema) - if err != nil { - return nil, err - } - if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { - return nil, err - } - if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { - return nil, err - } - return s, nil + return s, err } -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err +func (s *keyChangesStatements) Prepare() (err error) { + if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { + return err + } + if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { + return err + } + return nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return } func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, partition int32, fromOffset, toOffset int64, + ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { if toOffset == sarama.OffsetNewest { toOffset = math.MaxInt64 } latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { return nil, 0, err } diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 52f3a7f6b..b71cc1a7a 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -16,6 +16,7 @@ package postgres import ( "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/setup/config" ) @@ -51,6 +52,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + m := sqlutil.NewMigrations() + deltas.LoadRefactorKeyChanges(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = kc.Prepare(); err != nil { + return nil, err + } d := &shared.Database{ DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5bd8be368..5914d28e1 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -135,14 +135,16 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st return result, err } -func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) +func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err }) + return } -func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset) +func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go new file mode 100644 index 000000000..fbc548c38 --- /dev/null +++ b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go @@ -0,0 +1,76 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func LoadRefactorKeyChanges(m *sqlutil.Migrations) { + m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func UpRefactorKeyChanges(tx *sql.Tx) error { + // start counting from the last max offset, else 0. + var maxOffset int64 + var userID string + _ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) + + _, err := tx.Exec(` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + // to start counting from maxOffset, insert a row with that value + if userID != "" { + _, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) + return err + } + return nil +} + +func DownRefactorKeyChanges(tx *sql.Tx) error { + _, err := tx.Exec(` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index b4753ccc5..d43c15ca9 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -27,27 +27,22 @@ import ( var keyChangesSchema = ` -- Stores key change information about users. Used to determine when to send updated device lists to clients. CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - offset BIGINT NOT NULL, + change_id INTEGER PRIMARY KEY AUTOINCREMENT, -- The key owner user_id TEXT NOT NULL, - UNIQUE (partition, offset) + UNIQUE (user_id) ); ` -// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. -// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will -// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (partition, offset)" + - " DO UPDATE SET user_id = $3" + "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " RETURNING change_id" -// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just -// take the max offset value as the latest offset. const selectKeyChangesSQL = "" + - "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" type keyChangesStatements struct { db *sql.DB @@ -60,31 +55,32 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { db: db, } _, err := db.Exec(keyChangesSchema) - if err != nil { - return nil, err - } - if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { - return nil, err - } - if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { - return nil, err - } - return s, nil + return s, err } -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err +func (s *keyChangesStatements) Prepare() (err error) { + if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { + return err + } + if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { + return err + } + return nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return } func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, partition int32, fromOffset, toOffset int64, + ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { if toOffset == sarama.OffsetNewest { toOffset = math.MaxInt64 } latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { return nil, 0, err } diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index ee1746cd6..50ce00d05 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -17,6 +17,7 @@ package sqlite3 import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/setup/config" ) @@ -49,6 +50,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + + m := sqlutil.NewMigrations() + deltas.LoadRefactorKeyChanges(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = kc.Prepare(); err != nil { + return nil, err + } d := &shared.Database{ DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 4e0a8af1d..2f8cf809b 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -44,15 +44,18 @@ func MustNotError(t *testing.T, err error) { func TestKeyChanges(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest) + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 2 { - t.Fatalf("KeyChanges: got latest=%d want 2", latest) + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) } if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) @@ -62,15 +65,21 @@ func TestKeyChanges(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest) + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 2 { - t.Fatalf("KeyChanges: got latest=%d want 2", latest) + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) } if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) @@ -80,15 +89,18 @@ func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1) + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 1 { - t.Fatalf("KeyChanges: got latest=%d want 1", latest) + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) } if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 612eeb867..0d94c94cc 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -44,10 +44,12 @@ type DeviceKeys interface { } type KeyChanges interface { - InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + Prepare() error } type StaleDeviceLists interface { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f49536f4e..d4c5ebb5b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -842,9 +842,13 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } - if roomInfo == nil || roomInfo.IsStub { + if roomInfo == nil { return nil, fmt.Errorf("room %s doesn't exist", roomID) } + // e.g invited rooms + if roomInfo.IsStub { + return nil, nil + } eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) if err == sql.ErrNoRows { // No rooms have an event of this type, otherwise we'd have an event type NID diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 76b143d8b..97685cc04 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -17,7 +17,6 @@ package consumers import ( "context" "encoding/json" - "sync" "github.com/Shopify/sarama" "github.com/getsentry/sentry-go" @@ -34,16 +33,14 @@ import ( // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - keyChangeConsumer *internal.ContinualConsumer - db storage.Database - notifier *notifier.Notifier - stream types.PartitionedStreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.RoomserverInternalAPI - keyAPI api.KeyInternalAPI - partitionToOffset map[int32]int64 - partitionToOffsetMu sync.Mutex + ctx context.Context + keyChangeConsumer *internal.ContinualConsumer + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider + serverName gomatrixserverlib.ServerName // our server name + rsAPI roomserverAPI.RoomserverInternalAPI + keyAPI api.KeyInternalAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. @@ -57,7 +54,7 @@ func NewOutputKeyChangeEventConsumer( rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, notifier *notifier.Notifier, - stream types.PartitionedStreamProvider, + stream types.StreamProvider, ) *OutputKeyChangeEventConsumer { consumer := internal.ContinualConsumer{ @@ -69,16 +66,14 @@ func NewOutputKeyChangeEventConsumer( } s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - keyChangeConsumer: &consumer, - db: store, - serverName: serverName, - keyAPI: keyAPI, - rsAPI: rsAPI, - partitionToOffset: make(map[int32]int64), - partitionToOffsetMu: sync.Mutex{}, - notifier: notifier, - stream: stream, + ctx: process.Context(), + keyChangeConsumer: &consumer, + db: store, + serverName: serverName, + keyAPI: keyAPI, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } consumer.ProcessMessage = s.onMessage @@ -88,24 +83,10 @@ func NewOutputKeyChangeEventConsumer( // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { - offsets, err := s.keyChangeConsumer.StartOffsets() - s.partitionToOffsetMu.Lock() - for _, o := range offsets { - s.partitionToOffset[o.Partition] = o.Offset - } - s.partitionToOffsetMu.Unlock() - return err -} - -func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) { - s.partitionToOffsetMu.Lock() - defer s.partitionToOffsetMu.Unlock() - s.partitionToOffset[msg.Partition] = msg.Offset + return s.keyChangeConsumer.Start() } func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { - defer s.updateOffset(msg) - var m api.DeviceMessage if err := json.Unmarshal(msg.Value, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") @@ -118,15 +99,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } switch m.Type { case api.TypeCrossSigningUpdate: - return s.onCrossSigningMessage(m, msg.Offset, msg.Partition) + return s.onCrossSigningMessage(m, m.DeviceChangeID) case api.TypeDeviceKeyUpdate: fallthrough default: - return s.onDeviceKeyMessage(m, msg.Offset, msg.Partition) + return s.onDeviceKeyMessage(m, m.DeviceChangeID) } } -func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64, partition int32) error { +func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error { if m.DeviceKeys == nil { return nil } @@ -143,10 +124,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 - posUpdate := types.LogPosition{ - Offset: offset, - Partition: partition, - } + posUpdate := types.StreamPosition(deviceChangeID) s.stream.Advance(posUpdate) for userID := range queryRes.UserIDsToCount { @@ -156,7 +134,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o return nil } -func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64, partition int32) error { +func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error { output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse @@ -170,10 +148,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 - posUpdate := types.LogPosition{ - Offset: offset, - Partition: partition, - } + posUpdate := types.StreamPosition(deviceChangeID) s.stream.Advance(posUpdate) for userID := range queryRes.UserIDsToCount { diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 56a438fb5..41efd4a07 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -47,8 +47,8 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // be already filled in with join/leave information. func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - userID string, res *types.Response, from, to types.LogPosition, -) (newPos types.LogPosition, hasNew bool, err error) { + userID string, res *types.Response, from, to types.StreamPosition, +) (newPos types.StreamPosition, hasNew bool, err error) { // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. newlyJoinedRooms := joinedRooms(res, userID) @@ -64,27 +64,18 @@ func DeviceListCatchup( } // now also track users who we already share rooms with but who have updated their devices between the two tokens - - var partition int32 - var offset int64 - partition = -1 - offset = sarama.OffsetOldest - // Extract partition/offset from sync token - // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. - if !from.IsEmpty() { - partition = from.Partition - offset = from.Offset + offset := sarama.OffsetOldest + toOffset := sarama.OffsetNewest + if to > 0 && to > from { + toOffset = int64(to) } - var toOffset int64 - toOffset = sarama.OffsetNewest - if toLog := to; toLog.Partition == partition && toLog.Offset > 0 { - toOffset = toLog.Offset + if from > 0 { + offset = int64(from) } var queryRes keyapi.QueryKeyChangesResponse keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ - Partition: partition, - Offset: offset, - ToOffset: toOffset, + Offset: offset, + ToOffset: toOffset, }, &queryRes) if queryRes.Error != nil { // don't fail the catchup because we may have got useful information by tracking membership @@ -95,8 +86,8 @@ func DeviceListCatchup( var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( - "QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", - partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, + "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v", + offset, toOffset, queryRes.Offset, queryRes.UserIDs, ) userSet := make(map[string]bool) for _, userID := range res.DeviceLists.Changed { @@ -125,13 +116,8 @@ func DeviceListCatchup( res.DeviceLists.Left = append(res.DeviceLists.Left, userID) } } - // set the new token - to = types.LogPosition{ - Partition: queryRes.Partition, - Offset: queryRes.Offset, - } - return to, hasNew, nil + return types.StreamPosition(queryRes.Offset), hasNew, nil } // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index e52e55564..d9fb9cf82 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -6,7 +6,6 @@ import ( "sort" "testing" - "github.com/Shopify/sarama" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" @@ -16,11 +15,7 @@ import ( var ( syncingUser = "@alice:localhost" - emptyToken = types.LogPosition{} - newestToken = types.LogPosition{ - Offset: sarama.OffsetNewest, - Partition: 0, - } + emptyToken = types.StreamPosition(0) ) type mockKeyAPI struct{} @@ -186,7 +181,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -209,7 +204,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -232,7 +227,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { "!another:room": {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -254,7 +249,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { "!another:room": {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -313,7 +308,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { roomID: {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -341,7 +336,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -427,7 +422,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { }, } _, hasNew, err := DeviceListCatchup( - context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken, + context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken, ) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 9ea9d088f..6ff8a7fd5 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -10,7 +10,7 @@ import ( ) type DeviceListStreamProvider struct { - PartitionedStreamProvider + StreamProvider rsAPI api.RoomserverInternalAPI keyAPI keyapi.KeyInternalAPI } @@ -18,15 +18,15 @@ type DeviceListStreamProvider struct { func (p *DeviceListStreamProvider) CompleteSync( ctx context.Context, req *types.SyncRequest, -) types.LogPosition { +) types.StreamPosition { return p.LatestPosition(ctx) } func (p *DeviceListStreamProvider) IncrementalSync( ctx context.Context, req *types.SyncRequest, - from, to types.LogPosition, -) types.LogPosition { + from, to types.StreamPosition, +) types.StreamPosition { var err error to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index ba4118df5..6b02c75ea 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -18,7 +18,7 @@ type Streams struct { InviteStreamProvider types.StreamProvider SendToDeviceStreamProvider types.StreamProvider AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.PartitionedStreamProvider + DeviceListStreamProvider types.StreamProvider } func NewSyncStreamProviders( @@ -48,9 +48,9 @@ func NewSyncStreamProviders( userAPI: userAPI, }, DeviceListStreamProvider: &DeviceListStreamProvider{ - PartitionedStreamProvider: PartitionedStreamProvider{DB: d}, - rsAPI: rsAPI, - keyAPI: keyAPI, + StreamProvider: StreamProvider{DB: d}, + rsAPI: rsAPI, + keyAPI: keyAPI, }, } diff --git a/syncapi/streams/template_pstream.go b/syncapi/streams/template_pstream.go deleted file mode 100644 index 265e22a20..000000000 --- a/syncapi/streams/template_pstream.go +++ /dev/null @@ -1,38 +0,0 @@ -package streams - -import ( - "context" - "sync" - - "github.com/matrix-org/dendrite/syncapi/storage" - "github.com/matrix-org/dendrite/syncapi/types" -) - -type PartitionedStreamProvider struct { - DB storage.Database - latest types.LogPosition - latestMutex sync.RWMutex -} - -func (p *PartitionedStreamProvider) Setup() { -} - -func (p *PartitionedStreamProvider) Advance( - latest types.LogPosition, -) { - p.latestMutex.Lock() - defer p.latestMutex.Unlock() - - if latest.IsAfter(&p.latest) { - p.latest = latest - } -} - -func (p *PartitionedStreamProvider) LatestPosition( - ctx context.Context, -) types.LogPosition { - p.latestMutex.RLock() - defer p.latestMutex.RUnlock() - - return p.latest -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index a45736106..ca35951a0 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -140,6 +140,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // Extract values from request syncReq, err := newSyncRequest(req, *device, rp.db) if err != nil { + if err == types.ErrMalformedSyncToken { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index 93ed12661..f6185fcb5 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -42,11 +42,3 @@ type StreamProvider interface { // LatestPosition returns the latest stream position for this stream. LatestPosition(ctx context.Context) StreamPosition } - -type PartitionedStreamProvider interface { - Setup() - Advance(latest LogPosition) - CompleteSync(ctx context.Context, req *SyncRequest) LogPosition - IncrementalSync(ctx context.Context, req *SyncRequest, from, to LogPosition) LogPosition - LatestPosition(ctx context.Context) LogPosition -} diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 44e718b38..68c308d83 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -16,6 +16,7 @@ package types import ( "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -26,13 +27,10 @@ import ( ) var ( - // ErrInvalidSyncTokenType is returned when an attempt at creating a - // new instance of SyncToken with an invalid type (i.e. neither "s" - // nor "t"). - ErrInvalidSyncTokenType = fmt.Errorf("sync token has an unknown prefix (should be either s or t)") - // ErrInvalidSyncTokenLen is returned when the pagination token is an - // invalid length - ErrInvalidSyncTokenLen = fmt.Errorf("sync token has an invalid length") + // This error is returned when parsing sync tokens if the token is invalid. Callers can use this + // error to detect whether to 400 or 401 the client. It is recommended to 401 them to force a + // logout. + ErrMalformedSyncToken = errors.New("malformed sync token") ) type StateDelta struct { @@ -47,27 +45,6 @@ type StateDelta struct { // StreamPosition represents the offset in the sync stream a client is at. type StreamPosition int64 -// LogPosition represents the offset in a Kafka log a client is at. -type LogPosition struct { - Partition int32 - Offset int64 -} - -func (p *LogPosition) IsEmpty() bool { - return p.Offset == 0 -} - -// IsAfter returns true if this position is after `lp`. -func (p *LogPosition) IsAfter(lp *LogPosition) bool { - if lp == nil { - return false - } - if p.Partition != lp.Partition { - return false - } - return p.Offset > lp.Offset -} - // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { *gomatrixserverlib.HeaderedEvent @@ -124,7 +101,7 @@ type StreamingToken struct { SendToDevicePosition StreamPosition InvitePosition StreamPosition AccountDataPosition StreamPosition - DeviceListPosition LogPosition + DeviceListPosition StreamPosition } // This will be used as a fallback by json.Marshal. @@ -140,14 +117,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d_%d_%d", + "s%d_%d_%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, - t.InvitePosition, t.AccountDataPosition, + t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, ) - if dl := t.DeviceListPosition; !dl.IsEmpty() { - posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset) - } return posStr } @@ -166,14 +140,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.AccountDataPosition > other.AccountDataPosition: return true - case t.DeviceListPosition.IsAfter(&other.DeviceListPosition): + case t.DeviceListPosition > other.DeviceListPosition: return true } return false } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition == 0 && t.DeviceListPosition.IsEmpty() + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -208,7 +182,7 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) { if other.AccountDataPosition > t.AccountDataPosition { t.AccountDataPosition = other.AccountDataPosition } - if other.DeviceListPosition.IsAfter(&t.DeviceListPosition) { + if other.DeviceListPosition > t.DeviceListPosition { t.DeviceListPosition = other.DeviceListPosition } } @@ -292,16 +266,18 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { if len(tok) < 1 { - err = fmt.Errorf("empty stream token") + err = ErrMalformedSyncToken return } if tok[0] != SyncTokenTypeStream[0] { - err = fmt.Errorf("stream token must start with 's'") + err = ErrMalformedSyncToken return } - categories := strings.Split(tok[1:], ".") - parts := strings.Split(categories[0], "_") - var positions [6]StreamPosition + // Migration: Remove everything after and including '.' - we previously had tokens like: + // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions + tok = strings.Split(tok, ".")[0] + parts := strings.Split(tok[1:], "_") + var positions [7]StreamPosition for i, p := range parts { if i > len(positions) { break @@ -309,6 +285,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { var pos int pos, err = strconv.Atoi(p) if err != nil { + err = ErrMalformedSyncToken return } positions[i] = StreamPosition(pos) @@ -320,31 +297,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { SendToDevicePosition: positions[3], InvitePosition: positions[4], AccountDataPosition: positions[5], - } - // dl-0-1234 - // $log_name-$partition-$offset - for _, logStr := range categories[1:] { - segments := strings.Split(logStr, "-") - if len(segments) != 3 { - err = fmt.Errorf("invalid log position %q", logStr) - return - } - switch segments[0] { - case "dl": - // Device list syncing - var partition, offset int - if partition, err = strconv.Atoi(segments[1]); err != nil { - return - } - if offset, err = strconv.Atoi(segments[2]); err != nil { - return - } - token.DeviceListPosition.Partition = int32(partition) - token.DeviceListPosition.Offset = int64(offset) - default: - err = fmt.Errorf("unrecognised token type %q", segments[0]) - return - } + DeviceListPosition: positions[6], } return token, nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 3e5777888..cda178b37 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,50 +2,17 @@ package types import ( "encoding/json" - "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" ) -func TestNewSyncTokenWithLogs(t *testing.T) { - tests := map[string]*StreamingToken{ - "s4_0_0_0_0_0": { - PDUPosition: 4, - }, - "s4_0_0_0_0_0.dl-0-123": { - PDUPosition: 4, - DeviceListPosition: LogPosition{ - Partition: 0, - Offset: 123, - }, - }, - } - for tok, want := range tests { - got, err := NewStreamTokenFromString(tok) - if err != nil { - if want == nil { - continue // error expected - } - t.Errorf("%s errored: %s", tok, err) - continue - } - if !reflect.DeepEqual(got, *want) { - t.Errorf("%s mismatch: got %v want %v", tok, got, want) - } - gotStr := got.String() - if gotStr != tok { - t.Errorf("%s reserialisation mismatch: got %s want %s", tok, gotStr, tok) - } - } -} - func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, LogPosition{}}.String(), - "s3_1_0_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, 0, LogPosition{1, 2}}.String(), - "s3_1_2_3_5_0": StreamingToken{3, 1, 2, 3, 5, 0, LogPosition{}}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), + "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), + "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass { diff --git a/sytest-whitelist b/sytest-whitelist index 558eb29a6..f11cd96bf 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -588,3 +588,4 @@ User can invite remote user to room with version 9 Remote user can backfill in a room with version 9 Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 +Forward extremities remain so even after the next events are populated as outliers