From adf7b5929401f56bedba92ef778b5e56feefc479 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Tue, 28 Jul 2020 17:38:30 +0100 Subject: [PATCH 1/4] Persist partition|offset|user_id in the keyserver (#1226) * Persist partition|offset|user_id in the keyserver Required for a query API which will be used by the syncapi which will be called when a `/sync` request comes in which will return a list of user IDs of people who have changed their device keys between two tokens. * Add tests and fix maxOffset bug * s/offset/log_offset/g because 'offset' is a reserved word in postgres --- keyserver/keyserver.go | 1 + keyserver/producers/keychange.go | 7 ++ keyserver/storage/interface.go | 8 ++ .../storage/postgres/key_changes_table.go | 97 ++++++++++++++++++ keyserver/storage/postgres/storage.go | 5 + keyserver/storage/shared/storage.go | 9 ++ .../storage/sqlite3/key_changes_table.go | 98 +++++++++++++++++++ keyserver/storage/sqlite3/storage.go | 5 + keyserver/storage/storage_test.go | 57 +++++++++++ keyserver/storage/tables/interface.go | 5 + 10 files changed, 292 insertions(+) create mode 100644 keyserver/storage/postgres/key_changes_table.go create mode 100644 keyserver/storage/sqlite3/key_changes_table.go create mode 100644 keyserver/storage/storage_test.go diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 47c6a8c37..c748d7ce3 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -49,6 +49,7 @@ func NewInternalAPI( keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent), Producer: producer, + DB: db, } return &internal.KeyInternalAPI{ DB: db, diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 6683a9364..d59dd2002 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -15,10 +15,12 @@ package producers import ( + "context" "encoding/json" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" + 
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/sirupsen/logrus" ) @@ -26,6 +28,7 @@ import ( type KeyChange struct { Topic string Producer sarama.SyncProducer + DB storage.Database } // ProduceKeyChanges creates new change events for each key @@ -46,6 +49,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { if err != nil { return err } + err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + if err != nil { + return err + } logrus.WithFields(logrus.Fields{ "user_id": key.UserID, "device_id": key.DeviceID, diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 7a0328bd7..f4787790c 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -43,4 +43,12 @@ type Database interface { // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed + // their keys in some way. + StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) } diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go new file mode 100644 index 000000000..9d259f9f0 --- /dev/null +++ b/keyserver/storage/postgres/key_changes_table.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) +); +` + +// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. +// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will +// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" + + " DO UPDATE SET user_id = $3" + +// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just +// take the max offset value as the latest offset. 
+const selectKeyChangesSQL = "" + + "SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return nil, err + } + if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { + return nil, err + } + if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, partition int32, fromOffset int64, +) (userIDs []string, latestOffset int64, err error) { + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 4f3217b65..a1d1c0feb 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -34,9 +34,14 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s if err != nil { return nil, err } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } return &shared.Database{ DB: 
db, OneTimeKeysTable: otk, DeviceKeysTable: dk, + KeyChangesTable: kc, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 156b5b415..537a5f7b9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -28,6 +28,7 @@ type Database struct { DB *sql.DB OneTimeKeysTable tables.OneTimeKeys DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st }) return result, err } + +func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) +} + +func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset) +} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go new file mode 100644 index 000000000..b830214d1 --- /dev/null +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -0,0 +1,98 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) +); +` + +// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. +// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will +// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (partition, offset)" + + " DO UPDATE SET user_id = $3" + +// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just +// take the max offset value as the latest offset. 
+const selectKeyChangesSQL = "" + + "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return nil, err + } + if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { + return nil, err + } + if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, partition int32, fromOffset int64, +) (userIDs []string, latestOffset int64, err error) { + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index f3566ef5c..f9771cf16 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -37,9 +37,14 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) { if err != nil { return nil, err } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } return &shared.Database{ DB: db, OneTimeKeysTable: otk, 
DeviceKeysTable: dk, + KeyChangesTable: kc, }, nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go new file mode 100644 index 000000000..889724789 --- /dev/null +++ b/keyserver/storage/storage_test.go @@ -0,0 +1,57 @@ +package storage + +import ( + "context" + "reflect" + "testing" +) + +var ctx = context.Background() + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) + userIDs, latest, err := db.KeyChanges(ctx, 0, 1) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != 2 { + t.Fatalf("KeyChanges: got latest=%d want 2", latest) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } +} + +func TestKeyChangesNoDupes(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) + MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) + userIDs, latest, err := db.KeyChanges(ctx, 0, 0) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != 2 { + t.Fatalf("KeyChanges: got latest=%d want 2", latest) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 216be773b..824b9f0fe 100644 --- 
a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -35,3 +35,8 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) } + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) +} From 9a5fb489c5f80148a8512e61c95c8df7bb46d314 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Tue, 28 Jul 2020 18:25:16 +0100 Subject: [PATCH 2/4] Add QueryKeyChanges (#1228) Hook some things up to call it as well. --- keyserver/api/api.go | 17 ++++++++++++ keyserver/internal/internal.go | 11 ++++++++ keyserver/inthttp/client.go | 18 +++++++++++++ keyserver/inthttp/server.go | 11 ++++++++ syncapi/consumers/keychange.go | 42 +++++++++++++++++++++++------ syncapi/consumers/keychange_test.go | 42 +++++++++++++++++++---------- 6 files changed, 119 insertions(+), 22 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index d42fb60cf..406a252d5 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -26,6 +26,7 @@ type KeyInternalAPI interface { // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -131,3 +132,19 @@ type QueryKeysResponse struct { // Set if there was a fatal error processing this query Error *KeyError } + +type QueryKeyChangesRequest struct { + // The partition which had key events sent to + Partition int32 + // The offset 
of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index d3a6d4bae..240a56403 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -40,6 +40,17 @@ type KeyInternalAPI struct { Producer *producers.KeyChange } +func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { + userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset) + if err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + } + res.Offset = latest + res.UserIDs = userIDs +} + func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) a.uploadDeviceKeys(ctx, req, res) diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 4c0f1e53a..cd9cf70d4 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -29,6 +29,7 @@ const ( PerformUploadKeysPath = "/keyserver/performUploadKeys" PerformClaimKeysPath = "/keyserver/performClaimKeys" QueryKeysPath = "/keyserver/queryKeys" + QueryKeyChangesPath = "/keyserver/queryKeyChanges" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. 
@@ -101,3 +102,20 @@ func (h *httpKeyInternalAPI) QueryKeys( } } } + +func (h *httpKeyInternalAPI) QueryKeyChanges( + ctx context.Context, + request *api.QueryKeyChangesRequest, + response *api.QueryKeyChangesResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges") + defer span.Finish() + + apiURL := h.apiURL + QueryKeyChangesPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index ec78b6132..f3d2882c2 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -58,4 +58,15 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryKeyChangesPath, + httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { + request := api.QueryKeyChangesRequest{} + response := api.QueryKeyChangesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryKeyChanges(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 4a1c73090..78aff6011 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -26,16 +26,17 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. 
type OutputKeyChangeEventConsumer struct { - keyChangeConsumer *internal.ContinualConsumer - db storage.Database - serverName gomatrixserverlib.ServerName // our server name - currentStateAPI currentstateAPI.CurrentStateInternalAPI - // keyAPI api.KeyInternalAPI + keyChangeConsumer *internal.ContinualConsumer + db storage.Database + serverName gomatrixserverlib.ServerName // our server name + currentStateAPI currentstateAPI.CurrentStateInternalAPI + keyAPI api.KeyInternalAPI partitionToOffset map[int32]int64 partitionToOffsetMu sync.Mutex } @@ -46,6 +47,7 @@ func NewOutputKeyChangeEventConsumer( serverName gomatrixserverlib.ServerName, topic string, kafkaConsumer sarama.Consumer, + keyAPI api.KeyInternalAPI, currentStateAPI currentstateAPI.CurrentStateInternalAPI, store storage.Database, ) *OutputKeyChangeEventConsumer { @@ -60,6 +62,7 @@ func NewOutputKeyChangeEventConsumer( keyChangeConsumer: &consumer, db: store, serverName: serverName, + keyAPI: keyAPI, currentStateAPI: currentStateAPI, partitionToOffset: make(map[int32]int64), partitionToOffsetMu: sync.Mutex{}, @@ -115,21 +118,44 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er // be already filled in with join/leave information. func (s *OutputKeyChangeEventConsumer) Catchup( ctx context.Context, userID string, res *types.Response, tok types.StreamingToken, -) (hasNew bool, err error) { +) (newTok *types.StreamingToken, hasNew bool, err error) { // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. 
newlyJoinedRooms := joinedRooms(res, userID) newlyLeftRooms := leftRooms(res) if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 { changed, left, err := s.trackChangedUsers(ctx, userID, newlyJoinedRooms, newlyLeftRooms) if err != nil { - return false, err + return nil, false, err } res.DeviceLists.Changed = changed res.DeviceLists.Left = left hasNew = len(changed) > 0 || len(left) > 0 } - // TODO: now also track users who we already share rooms with but who have updated their devices between the two tokens + // now also track users who we already share rooms with but who have updated their devices between the two tokens + // TODO: Extract partition/offset from sync token + var partition int32 + var offset int64 + var queryRes api.QueryKeyChangesResponse + s.keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ + Partition: partition, + Offset: offset, + }, &queryRes) + if queryRes.Error != nil { + // don't fail the catchup because we may have got useful information by tracking membership + util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") + } else { + // TODO: Make a new streaming token using the new offset + userSet := make(map[string]bool) + for _, userID := range res.DeviceLists.Changed { + userSet[userID] = true + } + for _, userID := range queryRes.UserIDs { + if !userSet[userID] { + res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) + } + } + } return } diff --git a/syncapi/consumers/keychange_test.go b/syncapi/consumers/keychange_test.go index 7322e2083..f8e965700 100644 --- a/syncapi/consumers/keychange_test.go +++ b/syncapi/consumers/keychange_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/matrix-org/dendrite/currentstateserver/api" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -15,6 +16,19 @@ var ( syncingUser = "@alice:localhost" ) +type mockKeyAPI struct{} + +func (k *mockKeyAPI) 
PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { +} + +// PerformClaimKeys claims one-time keys for use in pre-key messages +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { +} +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { +} +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +} + type mockCurrentStateAPI struct { roomIDToJoinedMembers map[string][]string } @@ -144,7 +158,7 @@ func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { newShareUser := "@bill:localhost" newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, newShareUser}, "!another:room": {syncingUser}, @@ -153,7 +167,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -167,7 +181,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { removeUser := "@bill:localhost" 
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {removeUser}, "!another:room": {syncingUser}, @@ -176,7 +190,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -190,7 +204,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { existingUser := "@bob:localhost" newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser}, @@ -199,7 +213,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, 
syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -212,7 +226,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { existingUser := "@bob:localhost" newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {existingUser}, "!another:room": {syncingUser, existingUser}, @@ -221,7 +235,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -234,7 +248,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { existingUser := "@bob1:localhost" roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ roomID: {syncingUser, existingUser}, }, @@ -280,7 +294,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { jr.Timeline.Events = roomTimelineEvents syncResponse.Rooms.Join[roomID] = jr - 
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -297,7 +311,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { newlyLeftUser2 := "@debra:localhost" newlyJoinedRoom := "!join:bar" newlyLeftRoom := "!left:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2}, newlyLeftRoom: {newlyLeftUser, newlyLeftUser2}, @@ -308,7 +322,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -333,7 +347,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { newShareUser := "@berta:localhost" newShareUser2 := "@bobby:localhost" roomID := "!join:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ + consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ roomID: {newShareUser, newShareUser2}, "!another:room": {syncingUser}, 
@@ -393,7 +407,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { lr.Timeline.Events = roomEvents syncResponse.Rooms.Leave[roomID] = lr - hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } From 0fdd4f14d123e76bd3d0368947d3aab84a787946 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 29 Jul 2020 19:00:04 +0100 Subject: [PATCH 3/4] Add support for logs in StreamingToken (#1229) * Add support for logs in StreamingToken Tokens now end up looking like `s11_22|dl-0-123|ab-0-12224` where `dl` and `ab` are log names, `0` is the partition and `123` and `12224` are the offsets. * Also test reserialisation * s/|/./g so tokens url escape nicely --- syncapi/consumers/clientapi.go | 2 +- syncapi/consumers/eduserver_sendtodevice.go | 2 +- syncapi/consumers/eduserver_typing.go | 4 +- syncapi/consumers/keychange_test.go | 15 +-- syncapi/consumers/roomserver.go | 6 +- syncapi/storage/shared/syncserver.go | 4 +- syncapi/storage/storage_test.go | 32 +++--- syncapi/sync/notifier_test.go | 10 +- syncapi/sync/request.go | 2 +- syncapi/sync/requestpool.go | 2 +- syncapi/types/types.go | 105 ++++++++++++++++++-- syncapi/types/types_test.go | 60 ++++++++++- 12 files changed, 189 insertions(+), 55 deletions(-) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index ad6290e3f..f7cf96d94 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -91,7 +91,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0)) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil)) return nil } diff --git 
a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 487018031..06a8928da 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -106,7 +106,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) s.notifier.OnNewSendToDevice( output.UserID, []string{output.DeviceID}, - types.NewStreamToken(0, streamPos), + types.NewStreamToken(0, streamPos, nil), ) return nil diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index 12b1efbc0..0a9a9c0cd 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -65,7 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.notifier.OnNewEvent( nil, roomID, nil, - types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)), + types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil), ) }) @@ -94,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) } - s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos)) + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil)) return nil } diff --git a/syncapi/consumers/keychange_test.go b/syncapi/consumers/keychange_test.go index f8e965700..3ecb3f583 100644 --- a/syncapi/consumers/keychange_test.go +++ b/syncapi/consumers/keychange_test.go @@ -14,6 +14,7 @@ import ( var ( syncingUser = "@alice:localhost" + emptyToken = types.NewStreamToken(0, 0, nil) ) type mockKeyAPI struct{} @@ -167,7 +168,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - _, hasNew, err := 
consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -190,7 +191,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -213,7 +214,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -235,7 +236,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -294,7 +295,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { jr.Timeline.Events = roomTimelineEvents syncResponse.Rooms.Join[roomID] = jr - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, 
err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -322,7 +323,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -407,7 +408,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { lr.Timeline.Events = roomEvents syncResponse.Rooms.Leave[roomID] = lr - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) + _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index c65027168..da4a5366c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -158,7 +158,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0)) + s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) return nil } @@ -176,7 +176,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0)) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil)) return nil } @@ -194,7 +194,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( } // 
Notify any active sync requests that the invite has been retired. // Invites share the same stream counter as PDUs - s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0)) + s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil)) return nil } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index e1312671b..dd5b838ce 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -434,7 +434,7 @@ func (d *Database) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition())) + sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil) return } @@ -731,7 +731,7 @@ func (d *Database) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. 
err = d.addEDUDeltaToResponse( - types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res, + types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res, ) if err != nil { return nil, err diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 474d3222b..1f679def3 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -163,7 +163,7 @@ func TestSyncResponse(t *testing.T) { Name: "IncrementalSync penultimate", DoSync: func() (*types.Response, error) { from := types.NewStreamToken( // pretend we are at the penultimate event - positions[len(positions)-2], types.StreamPosition(0), + positions[len(positions)-2], types.StreamPosition(0), nil, ) res := types.NewResponse() return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) @@ -176,7 +176,7 @@ func TestSyncResponse(t *testing.T) { Name: "IncrementalSync limited", DoSync: func() (*types.Response, error) { from := types.NewStreamToken( // pretend we are 10 events behind - positions[len(positions)-11], types.StreamPosition(0), + positions[len(positions)-11], types.StreamPosition(0), nil, ) res := types.NewResponse() // limit is set to 5 @@ -219,7 +219,7 @@ func TestSyncResponse(t *testing.T) { if err != nil { st.Fatalf("failed to do sync: %s", err) } - next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition()) + next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil) if res.NextBatch != next.String() { st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) } @@ -243,7 +243,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { t.Fatalf("failed to get SyncPosition: %s", err) } from := types.NewStreamToken( - positions[len(positions)-2], types.StreamPosition(0), + positions[len(positions)-2], types.StreamPosition(0), nil, ) res := types.NewResponse() @@ -288,7 +288,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) { t.Fatalf("failed to get SyncPosition: %s", err) } // head towards the 
beginning of time - to := types.NewStreamToken(0, 0) + to := types.NewStreamToken(0, 0, nil) // backpaginate 5 messages starting at the latest position. paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) @@ -531,14 +531,14 @@ func TestSendToDeviceBehaviour(t *testing.T) { // At this point there should be no messages. We haven't sent anything // yet. - events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil)) if err != nil { t.Fatal(err) } if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { t.Fatal("first call should have no updates") } - err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil)) if err != nil { return } @@ -556,14 +556,14 @@ func TestSendToDeviceBehaviour(t *testing.T) { // At this point we should get exactly one message. We're sending the sync position // that we were given from the update and the send-to-device update will be updated // in the database to reflect that this was the sync position we sent the message at. 
- events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) if err != nil { t.Fatal(err) } if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { t.Fatal("second call should have one update") } - err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) if err != nil { return } @@ -571,35 +571,35 @@ func TestSendToDeviceBehaviour(t *testing.T) { // At this point we should still have one message because we haven't progressed the // sync position yet. This is equivalent to the client failing to /sync and retrying // with the same position. - events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil)) if err != nil { t.Fatal(err) } if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { t.Fatal("third call should have one update still") } - err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil)) if err != nil { return } // At this point we should now have no updates, because we've progressed the sync // position. Therefore the update from before will not be sent again. 
- events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil)) if err != nil { t.Fatal(err) } if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { t.Fatal("fourth call should have no updates") } - err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil)) if err != nil { return } // At this point we should still have no updates, because no new updates have been // sent. - events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil)) if err != nil { t.Fatal(err) } @@ -636,7 +636,7 @@ func TestInviteBehaviour(t *testing.T) { } // both invite events should appear in a new sync beforeRetireRes := types.NewResponse() - beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false) + beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) if err != nil { t.Fatalf("IncrementalSync failed: %s", err) } @@ -651,7 +651,7 @@ func TestInviteBehaviour(t *testing.T) { t.Fatalf("failed to get SyncPosition: %s", err) } res := types.NewResponse() - res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false) + res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false) if err != nil { t.Fatalf("IncrementalSync failed: %s", err) } diff --git a/syncapi/sync/notifier_test.go 
b/syncapi/sync/notifier_test.go index f2a368ec2..5a4c7b31b 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -32,11 +32,11 @@ var ( randomMessageEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent - syncPositionVeryOld = types.NewStreamToken(5, 0) - syncPositionBefore = types.NewStreamToken(11, 0) - syncPositionAfter = types.NewStreamToken(12, 0) - syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1) - syncPositionAfter2 = types.NewStreamToken(13, 0) + syncPositionVeryOld = types.NewStreamToken(5, 0, nil) + syncPositionBefore = types.NewStreamToken(11, 0, nil) + syncPositionAfter = types.NewStreamToken(12, 0, nil) + syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil) + syncPositionAfter2 = types.NewStreamToken(13, 0, nil) ) var ( diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 41b18aa10..0996729e6 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -65,7 +65,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat since = &tok } if since == nil { - tok := types.NewStreamToken(0, 0) + tok := types.NewStreamToken(0, 0, nil) since = &tok } timelineLimit := DefaultTimelineLimit diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 196d446a2..bf6a9e01f 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -138,7 +138,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. 
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { res = types.NewResponse() - since := types.NewStreamToken(0, 0) + since := types.NewStreamToken(0, 0, nil) if req.since != nil { since = *req.since } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 7dc022811..7bba8e522 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -39,6 +39,23 @@ var ( // StreamPosition represents the offset in the sync stream a client is at. type StreamPosition int64 +// LogPosition represents the offset in a Kafka log a client is at. +type LogPosition struct { + Partition int32 + Offset int64 +} + +// IsAfter returns true if this position is after `lp`. +func (p *LogPosition) IsAfter(lp *LogPosition) bool { + if lp == nil { + return false + } + if p.Partition != lp.Partition { + return false + } + return p.Offset > lp.Offset +} + // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { gomatrixserverlib.HeaderedEvent @@ -90,6 +107,15 @@ const ( type StreamingToken struct { syncToken + logs map[string]*LogPosition +} + +func (t *StreamingToken) Log(name string) *LogPosition { + l, ok := t.logs[name] + if !ok { + return nil + } + return l } func (t *StreamingToken) PDUPosition() StreamPosition { @@ -99,7 +125,15 @@ func (t *StreamingToken) EDUPosition() StreamPosition { return t.Positions[1] } func (t *StreamingToken) String() string { - return t.syncToken.String() + logStrings := []string{ + t.syncToken.String(), + } + for name, lp := range t.logs { + logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset) + logStrings = append(logStrings, logStr) + } + // E.g s11_22_33.dl0-134.ab1-441 + return strings.Join(logStrings, ".") } // IsAfter returns true if ANY position in this token is greater than `other`. 
@@ -109,12 +143,22 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true } } + for name := range t.logs { + otherLog := other.Log(name) + if otherLog == nil { + continue + } + if t.logs[name].IsAfter(otherLog) { + return true + } + } return false } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // If the latter StreamingToken contains a field that is not 0, it is considered an update, // and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. +// If the other token has a log, they will replace any existing log on this token. func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { ret.Type = t.Type ret.Positions = make([]StreamPosition, len(t.Positions)) @@ -125,6 +169,13 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) } ret.Positions[i] = other.Positions[i] } + for name := range t.logs { + otherLog := other.Log(name) + if otherLog == nil { + continue + } + t.logs[name] = otherLog + } return ret } @@ -139,7 +190,7 @@ func (t *TopologyToken) PDUPosition() StreamPosition { return t.Positions[1] } func (t *TopologyToken) StreamToken() StreamingToken { - return NewStreamToken(t.PDUPosition(), 0) + return NewStreamToken(t.PDUPosition(), 0, nil) } func (t *TopologyToken) String() string { return t.syncToken.String() @@ -174,9 +225,9 @@ func (t *TopologyToken) Decrement() { // error if the token couldn't be parsed into an int64, or if the token type // isn't a known type (returns ErrInvalidSyncTokenType in the latter // case). 
-func newSyncTokenFromString(s string) (token *syncToken, err error) { +func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) { if len(s) == 0 { - return nil, ErrInvalidSyncTokenLen + return nil, nil, ErrInvalidSyncTokenLen } token = new(syncToken) @@ -185,16 +236,17 @@ func newSyncTokenFromString(s string) (token *syncToken, err error) { switch t := SyncTokenType(s[:1]); t { case SyncTokenTypeStream, SyncTokenTypeTopology: token.Type = t - positions = strings.Split(s[1:], "_") + categories = strings.Split(s[1:], ".") + positions = strings.Split(categories[0], "_") default: - return nil, ErrInvalidSyncTokenType + return nil, nil, ErrInvalidSyncTokenType } for _, pos := range positions { if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { - return nil, err + return nil, nil, err } else if posInt < 0 { - return nil, errors.New("negative position not allowed") + return nil, nil, errors.New("negative position not allowed") } else { token.Positions = append(token.Positions, StreamPosition(posInt)) } @@ -215,7 +267,7 @@ func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken { } } func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { - t, err := newSyncTokenFromString(tok) + t, _, err := newSyncTokenFromString(tok) if err != nil { return } @@ -233,16 +285,20 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { } // NewStreamToken creates a new sync token for /sync -func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken { +func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken { + if logs == nil { + logs = make(map[string]*LogPosition) + } return StreamingToken{ syncToken: syncToken{ Type: SyncTokenTypeStream, Positions: []StreamPosition{pduPos, eduPos}, }, + logs: logs, } } func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { - t, err := newSyncTokenFromString(tok) + t, categories, err 
:= newSyncTokenFromString(tok) if err != nil { return } @@ -254,8 +310,35 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions)) return } + logs := make(map[string]*LogPosition) + if len(categories) > 1 { + // dl-0-1234 + // $log_name-$partition-$offset + for _, logStr := range categories[1:] { + segments := strings.Split(logStr, "-") + if len(segments) != 3 { + err = fmt.Errorf("token %s - invalid log: %s", tok, logStr) + return + } + var partition int64 + partition, err = strconv.ParseInt(segments[1], 10, 32) + if err != nil { + return + } + var offset int64 + offset, err = strconv.ParseInt(segments[2], 10, 64) + if err != nil { + return + } + logs[segments[0]] = &LogPosition{ + Partition: int32(partition), + Offset: offset, + } + } + } return StreamingToken{ syncToken: *t, + logs: logs, }, nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 1e27a8e32..7590ea522 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -1,11 +1,61 @@ package types -import "testing" +import ( + "reflect" + "testing" +) + +func TestNewSyncTokenWithLogs(t *testing.T) { + tests := map[string]*StreamingToken{ + "s4_0": &StreamingToken{ + syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}}, + logs: make(map[string]*LogPosition), + }, + "s4_0.dl-0-123": &StreamingToken{ + syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}}, + logs: map[string]*LogPosition{ + "dl": &LogPosition{ + Partition: 0, + Offset: 123, + }, + }, + }, + "s4_0.dl-0-123.ab-1-14419482332": &StreamingToken{ + syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}}, + logs: map[string]*LogPosition{ + "ab": &LogPosition{ + Partition: 1, + Offset: 14419482332, + }, + "dl": &LogPosition{ + Partition: 0, + Offset: 123, + }, + }, + }, + } + for tok, want := range tests { + got, err := 
NewStreamTokenFromString(tok) + if err != nil { + if want == nil { + continue // error expected + } + t.Errorf("%s errored: %s", tok, err) + continue + } + if !reflect.DeepEqual(got, *want) { + t.Errorf("%s mismatch: got %v want %v", tok, got, want) + } + if got.String() != tok { + t.Errorf("%s reserialisation mismatch: got %s want %s", tok, got.String(), tok) + } + } +} func TestNewSyncTokenFromString(t *testing.T) { shouldPass := map[string]syncToken{ - "s4_0": NewStreamToken(4, 0).syncToken, - "s3_1": NewStreamToken(3, 1).syncToken, + "s4_0": NewStreamToken(4, 0, nil).syncToken, + "s3_1": NewStreamToken(3, 1, nil).syncToken, "t3_1": NewTopologyToken(3, 1).syncToken, } @@ -21,7 +71,7 @@ func TestNewSyncTokenFromString(t *testing.T) { } for test, expected := range shouldPass { - result, err := newSyncTokenFromString(test) + result, _, err := newSyncTokenFromString(test) if err != nil { t.Error(err) } @@ -31,7 +81,7 @@ func TestNewSyncTokenFromString(t *testing.T) { } for _, test := range shouldFail { - if _, err := newSyncTokenFromString(test); err == nil { + if _, _, err := newSyncTokenFromString(test); err == nil { t.Errorf("input '%v' should have errored but didn't", test) } } From 9355fb5ac8c911bdbde6dcc0f279f716d8a8f60b Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 30 Jul 2020 11:15:46 +0100 Subject: [PATCH 4/4] Hook up device list updates to the sync notifier (#1231) * WIP hooking up key changes * Fix import cycle, get tests passing and binary compiling * Linting and update whitelist --- cmd/dendrite-sync-api-server/main.go | 4 +- internal/setup/monolith.go | 3 +- keyserver/api/api.go | 2 + keyserver/internal/internal.go | 4 + keyserver/producers/keychange.go | 9 + syncapi/consumers/keychange.go | 191 ++------------- syncapi/internal/keychange.go | 219 ++++++++++++++++++ .../{consumers => internal}/keychange_test.go | 93 ++++---- syncapi/sync/notifier.go | 10 + syncapi/sync/requestpool.go | 32 ++- syncapi/syncapi.go | 14 +- syncapi/types/types.go | 4 + 
sytest-whitelist | 1 + 13 files changed, 356 insertions(+), 230 deletions(-) create mode 100644 syncapi/internal/keychange.go rename syncapi/{consumers => internal}/keychange_test.go (86%) diff --git a/cmd/dendrite-sync-api-server/main.go b/cmd/dendrite-sync-api-server/main.go index d67395fb3..0761a1d10 100644 --- a/cmd/dendrite-sync-api-server/main.go +++ b/cmd/dendrite-sync-api-server/main.go @@ -29,7 +29,9 @@ func main() { rsAPI := base.RoomserverHTTPClient() - syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, cfg) + syncapi.AddPublicRoutes( + base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, base.KeyServerHTTPClient(), base.CurrentStateAPIClient(), + federation, cfg) base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI)) diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 1f6d9a761..f33f97ee4 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -77,6 +77,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) { ) mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) syncapi.AddPublicRoutes( - publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, m.Config, + publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, + m.KeyAPI, m.StateAPI, m.FedClient, m.Config, ) } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 406a252d5..c9afb09cc 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -143,6 +143,8 @@ type QueryKeyChangesRequest struct { type QueryKeyChangesResponse struct { // The set of users who have had their keys change. UserIDs []string + // The partition being served - useful if the partition is unknown at request time + Partition int32 // The latest offset represented in this response. Offset int64 // Set if there was a problem handling the request. 
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 240a56403..9a41e44fc 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -41,6 +41,9 @@ type KeyInternalAPI struct { } func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { + if req.Partition < 0 { + req.Partition = a.Producer.DefaultPartition() + } userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset) if err != nil { res.Error = &api.KeyError{ @@ -48,6 +51,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC } } res.Offset = latest + res.Partition = req.Partition res.UserIDs = userIDs } diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index d59dd2002..c51d9f55d 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -31,6 +31,15 @@ type KeyChange struct { DB storage.Database } +// DefaultPartition returns the default partition this process is sending key changes to. +// NB: A keyserver MUST send key changes to only 1 partition or else query operations will +// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but +// then all keyservers must be queried to calculate the entire set of key changes between +// two sync tokens. 
+func (p *KeyChange) DefaultPartition() int32 { + return 0 +} + // ProduceKeyChanges creates new change events for each key func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { for _, key := range keys { diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 78aff6011..35978be71 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -23,10 +23,11 @@ import ( currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" + syncinternal "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" + syncapi "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -39,6 +40,7 @@ type OutputKeyChangeEventConsumer struct { keyAPI api.KeyInternalAPI partitionToOffset map[int32]int64 partitionToOffsetMu sync.Mutex + notifier *syncapi.Notifier } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. 
@@ -47,6 +49,7 @@ func NewOutputKeyChangeEventConsumer( serverName gomatrixserverlib.ServerName, topic string, kafkaConsumer sarama.Consumer, + n *syncapi.Notifier, keyAPI api.KeyInternalAPI, currentStateAPI currentstateAPI.CurrentStateInternalAPI, store storage.Database, @@ -66,6 +69,7 @@ func NewOutputKeyChangeEventConsumer( currentStateAPI: currentStateAPI, partitionToOffset: make(map[int32]int64), partitionToOffsetMu: sync.Mutex{}, + notifier: n, } consumer.ProcessMessage = s.onMessage @@ -110,59 +114,22 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er return err } // TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams + posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{ + syncinternal.DeviceListLogName: &types.LogPosition{ + Offset: msg.Offset, + Partition: msg.Partition, + }, + }) + for userID := range queryRes.UserIDsToCount { + s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID) + } return nil } -// Catchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response -// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST -// be already filled in with join/leave information. -func (s *OutputKeyChangeEventConsumer) Catchup( - ctx context.Context, userID string, res *types.Response, tok types.StreamingToken, -) (newTok *types.StreamingToken, hasNew bool, err error) { - // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. 
- newlyJoinedRooms := joinedRooms(res, userID) - newlyLeftRooms := leftRooms(res) - if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 { - changed, left, err := s.trackChangedUsers(ctx, userID, newlyJoinedRooms, newlyLeftRooms) - if err != nil { - return nil, false, err - } - res.DeviceLists.Changed = changed - res.DeviceLists.Left = left - hasNew = len(changed) > 0 || len(left) > 0 - } - - // now also track users who we already share rooms with but who have updated their devices between the two tokens - // TODO: Extract partition/offset from sync token - var partition int32 - var offset int64 - var queryRes api.QueryKeyChangesResponse - s.keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ - Partition: partition, - Offset: offset, - }, &queryRes) - if queryRes.Error != nil { - // don't fail the catchup because we may have got useful information by tracking membership - util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") - } else { - // TODO: Make a new streaming token using the new offset - userSet := make(map[string]bool) - for _, userID := range res.DeviceLists.Changed { - userSet[userID] = true - } - for _, userID := range queryRes.UserIDs { - if !userSet[userID] { - res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) - } - } - } - return -} - func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) { // work out who we are now sharing rooms with which we previously were not and notify them about the joining // users keys: - changed, _, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), []string{ev.RoomID()}, nil) + changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), []string{ev.RoomID()}, nil) if err != nil { log.WithError(err).Error("OnJoinEvent: failed to work out changed users") return @@ -175,7 +142,7 @@ func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.Headere func (s 
*OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) { // work out who we are no longer sharing any rooms with and notify them about the leaving user - _, left, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), nil, []string{ev.RoomID()}) + _, left, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), nil, []string{ev.RoomID()}) if err != nil { log.WithError(err).Error("OnLeaveEvent: failed to work out left users") return @@ -186,129 +153,3 @@ func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.Header } } - -// nolint:gocyclo -func (s *OutputKeyChangeEventConsumer) trackChangedUsers( - ctx context.Context, userID string, newlyJoinedRooms, newlyLeftRooms []string, -) (changed, left []string, err error) { - // process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users. - - // Leave algorithm: - // - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'. - // - Get users in newly left room. - QueryCurrentState - // - Loop set of users and decrement by 1 for each user in newly left room. - // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. 
- var queryRes currentstateAPI.QuerySharedUsersResponse - err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{ - UserID: userID, - IncludeRoomIDs: newlyLeftRooms, - }, &queryRes) - if err != nil { - return nil, nil, err - } - var stateRes currentstateAPI.QueryBulkStateContentResponse - err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{ - RoomIDs: newlyLeftRooms, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: "*", - }, - }, - AllowWildcards: true, - }, &stateRes) - if err != nil { - return nil, nil, err - } - for _, state := range stateRes.Rooms { - for tuple, membership := range state { - if membership != gomatrixserverlib.Join { - continue - } - queryRes.UserIDsToCount[tuple.StateKey]-- - } - } - for userID, count := range queryRes.UserIDsToCount { - if count <= 0 { - left = append(left, userID) // left is returned - } - } - - // Join algorithm: - // - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'. - // - Get users in newly joined room - QueryCurrentState - // - Loop set of users in newly joined room, do they appear in the set of users prior to joining? - // - If yes: then they already shared a room in common, do nothing. 
- // - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]' - err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{ - UserID: userID, - ExcludeRoomIDs: newlyJoinedRooms, - }, &queryRes) - if err != nil { - return nil, left, err - } - err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{ - RoomIDs: newlyJoinedRooms, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: "*", - }, - }, - AllowWildcards: true, - }, &stateRes) - if err != nil { - return nil, left, err - } - for _, state := range stateRes.Rooms { - for tuple, membership := range state { - if membership != gomatrixserverlib.Join { - continue - } - // new user who we weren't previously sharing rooms with - if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - changed = append(changed, tuple.StateKey) // changed is returned - } - } - } - return changed, left, nil -} - -func joinedRooms(res *types.Response, userID string) []string { - var roomIDs []string - for roomID, join := range res.Rooms.Join { - // we would expect to see our join event somewhere if we newly joined the room. - // Normal events get put in the join section so it's not enough to know the room ID is present in 'join'. 
- newlyJoined := membershipEventPresent(join.State.Events, userID) - if newlyJoined { - roomIDs = append(roomIDs, roomID) - continue - } - newlyJoined = membershipEventPresent(join.Timeline.Events, userID) - if newlyJoined { - roomIDs = append(roomIDs, roomID) - } - } - return roomIDs -} - -func leftRooms(res *types.Response) []string { - roomIDs := make([]string, len(res.Rooms.Leave)) - i := 0 - for roomID := range res.Rooms.Leave { - roomIDs[i] = roomID - i++ - } - return roomIDs -} - -func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool { - for _, ev := range events { - // it's enough to know that we have our member event here, don't need to check membership content - // as it's implied by being in the respective section of the sync response. - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { - return true - } - } - return false -} diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go new file mode 100644 index 000000000..b594cc623 --- /dev/null +++ b/syncapi/internal/keychange.go @@ -0,0 +1,219 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package internal + +import ( + "context" + + "github.com/Shopify/sarama" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + "github.com/matrix-org/dendrite/keyserver/api" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +const DeviceListLogName = "dl" + +// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response +// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST +// be already filled in with join/leave information. +func DeviceListCatchup( + ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + userID string, res *types.Response, tok types.StreamingToken, +) (newTok *types.StreamingToken, hasNew bool, err error) { + // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. + newlyJoinedRooms := joinedRooms(res, userID) + newlyLeftRooms := leftRooms(res) + if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 { + changed, left, err := TrackChangedUsers(ctx, stateAPI, userID, newlyJoinedRooms, newlyLeftRooms) + if err != nil { + return nil, false, err + } + res.DeviceLists.Changed = changed + res.DeviceLists.Left = left + hasNew = len(changed) > 0 || len(left) > 0 + } + + // now also track users who we already share rooms with but who have updated their devices between the two tokens + + var partition int32 + var offset int64 + // Extract partition/offset from sync token + // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. 
+ logOffset := tok.Log(DeviceListLogName) + if logOffset != nil { + partition = logOffset.Partition + offset = logOffset.Offset + } else { + partition = -1 + offset = sarama.OffsetOldest + } + var queryRes api.QueryKeyChangesResponse + keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ + Partition: partition, + Offset: offset, + }, &queryRes) + if queryRes.Error != nil { + // don't fail the catchup because we may have got useful information by tracking membership + util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") + return + } + userSet := make(map[string]bool) + for _, userID := range res.DeviceLists.Changed { + userSet[userID] = true + } + for _, userID := range queryRes.UserIDs { + if !userSet[userID] { + res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) + hasNew = true + } + } + // Make a new streaming token using the new offset + tok.SetLog(DeviceListLogName, &types.LogPosition{ + Offset: queryRes.Offset, + Partition: queryRes.Partition, + }) + newTok = &tok + return +} + +// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. +// nolint:gocyclo +func TrackChangedUsers( + ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, +) (changed, left []string, err error) { + // process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users. + + // Leave algorithm: + // - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'. + // - Get users in newly left room. - QueryCurrentState + // - Loop set of users and decrement by 1 for each user in newly left room. + // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. 
+ var queryRes currentstateAPI.QuerySharedUsersResponse + err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{ + UserID: userID, + IncludeRoomIDs: newlyLeftRooms, + }, &queryRes) + if err != nil { + return nil, nil, err + } + var stateRes currentstateAPI.QueryBulkStateContentResponse + err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{ + RoomIDs: newlyLeftRooms, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: "*", + }, + }, + AllowWildcards: true, + }, &stateRes) + if err != nil { + return nil, nil, err + } + for _, state := range stateRes.Rooms { + for tuple, membership := range state { + if membership != gomatrixserverlib.Join { + continue + } + queryRes.UserIDsToCount[tuple.StateKey]-- + } + } + for userID, count := range queryRes.UserIDsToCount { + if count <= 0 { + left = append(left, userID) // left is returned + } + } + + // Join algorithm: + // - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'. + // - Get users in newly joined room - QueryCurrentState + // - Loop set of users in newly joined room, do they appear in the set of users prior to joining? + // - If yes: then they already shared a room in common, do nothing. 
+ // - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]' + err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{ + UserID: userID, + ExcludeRoomIDs: newlyJoinedRooms, + }, &queryRes) + if err != nil { + return nil, left, err + } + err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{ + RoomIDs: newlyJoinedRooms, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: "*", + }, + }, + AllowWildcards: true, + }, &stateRes) + if err != nil { + return nil, left, err + } + for _, state := range stateRes.Rooms { + for tuple, membership := range state { + if membership != gomatrixserverlib.Join { + continue + } + // new user who we weren't previously sharing rooms with + if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { + changed = append(changed, tuple.StateKey) // changed is returned + } + } + } + return changed, left, nil +} + +func joinedRooms(res *types.Response, userID string) []string { + var roomIDs []string + for roomID, join := range res.Rooms.Join { + // we would expect to see our join event somewhere if we newly joined the room. + // Normal events get put in the join section so it's not enough to know the room ID is present in 'join'. 
+ newlyJoined := membershipEventPresent(join.State.Events, userID) + if newlyJoined { + roomIDs = append(roomIDs, roomID) + continue + } + newlyJoined = membershipEventPresent(join.Timeline.Events, userID) + if newlyJoined { + roomIDs = append(roomIDs, roomID) + } + } + return roomIDs +} + +func leftRooms(res *types.Response) []string { + roomIDs := make([]string, len(res.Rooms.Leave)) + i := 0 + for roomID := range res.Rooms.Leave { + roomIDs[i] = roomID + i++ + } + return roomIDs +} + +func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool { + for _, ev := range events { + // it's enough to know that we have our member event here, don't need to check membership content + // as it's implied by being in the respective section of the sync response. + if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + return true + } + } + return false +} diff --git a/syncapi/consumers/keychange_test.go b/syncapi/internal/keychange_test.go similarity index 86% rename from syncapi/consumers/keychange_test.go rename to syncapi/internal/keychange_test.go index 3ecb3f583..d0d27e448 100644 --- a/syncapi/consumers/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -1,4 +1,4 @@ -package consumers +package internal import ( "context" @@ -159,18 +159,17 @@ func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { newShareUser := "@bill:localhost" newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ + syncResponse := types.NewResponse() + syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) + + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ 
newlyJoinedRoom: {syncingUser, newShareUser}, "!another:room": {syncingUser}, }, - }, nil) - syncResponse := types.NewResponse() - syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + }, syncingUser, syncResponse, emptyToken) if err != nil { - t.Fatalf("Catchup returned an error: %s", err) + t.Fatalf("DeviceListCatchup returned an error: %s", err) } assertCatchup(t, hasNew, syncResponse, wantCatchup{ hasNew: true, @@ -182,18 +181,17 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { removeUser := "@bill:localhost" newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ + syncResponse := types.NewResponse() + syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) + + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {removeUser}, "!another:room": {syncingUser}, }, - }, nil) - syncResponse := types.NewResponse() - syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + }, syncingUser, syncResponse, emptyToken) if err != nil { - t.Fatalf("Catchup returned an error: %s", err) + t.Fatalf("DeviceListCatchup returned an error: %s", err) } assertCatchup(t, hasNew, syncResponse, wantCatchup{ hasNew: true, @@ -205,16 +203,15 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { existingUser := "@bob:localhost" newlyJoinedRoom := 
"!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ + syncResponse := types.NewResponse() + syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) + + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser}, }, - }, nil) - syncResponse := types.NewResponse() - syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + }, syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -227,18 +224,17 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { existingUser := "@bob:localhost" newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ + syncResponse := types.NewResponse() + syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) + + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {existingUser}, "!another:room": {syncingUser, existingUser}, }, - }, nil) - syncResponse := types.NewResponse() - syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + }, syncingUser, syncResponse, emptyToken) if err != nil { - t.Fatalf("Catchup returned an error: %s", 
err) + t.Fatalf("DeviceListCatchup returned an error: %s", err) } assertCatchup(t, hasNew, syncResponse, wantCatchup{ hasNew: false, @@ -249,11 +245,6 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { existingUser := "@bob1:localhost" roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ - roomIDToJoinedMembers: map[string][]string{ - roomID: {syncingUser, existingUser}, - }, - }, nil) syncResponse := types.NewResponse() empty := "" roomStateEvents := []gomatrixserverlib.ClientEvent{ @@ -295,9 +286,13 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { jr.Timeline.Events = roomTimelineEvents syncResponse.Rooms.Join[roomID] = jr - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + roomIDToJoinedMembers: map[string][]string{ + roomID: {syncingUser, existingUser}, + }, + }, syncingUser, syncResponse, emptyToken) if err != nil { - t.Fatalf("Catchup returned an error: %s", err) + t.Fatalf("DeviceListCatchup returned an error: %s", err) } assertCatchup(t, hasNew, syncResponse, wantCatchup{ hasNew: false, @@ -312,18 +307,17 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { newlyLeftUser2 := "@debra:localhost" newlyJoinedRoom := "!join:bar" newlyLeftRoom := "!left:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ + syncResponse := types.NewResponse() + syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) + syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) + + _, hasNew, err := 
DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2}, newlyLeftRoom: {newlyLeftUser, newlyLeftUser2}, "!another:room": {syncingUser}, }, - }, nil) - syncResponse := types.NewResponse() - syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + }, syncingUser, syncResponse, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -348,12 +342,6 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { newShareUser := "@berta:localhost" newShareUser2 := "@bobby:localhost" roomID := "!join:bar" - consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockKeyAPI{}, &mockCurrentStateAPI{ - roomIDToJoinedMembers: map[string][]string{ - roomID: {newShareUser, newShareUser2}, - "!another:room": {syncingUser}, - }, - }, nil) syncResponse := types.NewResponse() roomEvents := []gomatrixserverlib.ClientEvent{ { @@ -408,9 +396,14 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { lr.Timeline.Events = roomEvents syncResponse.Rooms.Leave[roomID] = lr - _, hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, emptyToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + roomIDToJoinedMembers: map[string][]string{ + roomID: {newShareUser, newShareUser2}, + "!another:room": {syncingUser}, + }, + }, syncingUser, syncResponse, emptyToken) if err != nil { - t.Fatalf("Catchup returned an error: %s", err) + t.Fatalf("DeviceListCatchup returned an error: %s", err) } assertCatchup(t, hasNew, syncResponse, wantCatchup{ hasNew: true, diff --git a/syncapi/sync/notifier.go 
b/syncapi/sync/notifier.go index 325e75351..df23a2f4a 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -132,6 +132,16 @@ func (n *Notifier) OnNewSendToDevice( n.wakeupUserDevice(userID, deviceIDs, latestPos) } +func (n *Notifier) OnNewKeyChange( + posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos + n.wakeupUsers([]string{wakeUserID}, latestPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index bf6a9e01f..754d69833 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -22,6 +22,9 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -35,11 +38,16 @@ type RequestPool struct { db storage.Database userAPI userapi.UserInternalAPI notifier *Notifier + keyAPI keyapi.KeyInternalAPI + stateAPI currentstateAPI.CurrentStateInternalAPI } // NewRequestPool makes a new RequestPool -func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool { - return &RequestPool{db, userAPI, n} +func NewRequestPool( + db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, + stateAPI currentstateAPI.CurrentStateInternalAPI, +) *RequestPool { + return &RequestPool{db, userAPI, n, keyAPI, stateAPI} } // OnIncomingSyncRequest is called when a client makes a /sync request. 
This function MUST be @@ -164,6 +172,10 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea if err != nil { return } + res, err = rp.appendDeviceLists(res, req.device.UserID, since) + if err != nil { + return + } // Before we return the sync response, make sure that we take action on // any send-to-device database updates or deletions that we need to do. @@ -192,6 +204,22 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea return } +func (rp *RequestPool) appendDeviceLists( + data *types.Response, userID string, since types.StreamingToken, +) (*types.Response, error) { + // TODO: Currently this code will race which may result in duplicates but not missing data. + // This happens because, whilst we are told the range to fetch here (since / latest) the + // QueryKeyChanges API only exposes a "from" value (on purpose to avoid racing, which then + // returns the latest position with which the response has authority on). We'd need to tweak + // the API to expose a "to" value to fix this. 
+ _, _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since) + if err != nil { + return nil, err + } + + return data, nil +} + // nolint:gocyclo func (rp *RequestPool) appendAccountData( data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index caf91e27e..754cd5026 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,9 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + currentstateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -39,6 +41,8 @@ func AddPublicRoutes( consumer sarama.Consumer, userAPI userapi.UserInternalAPI, rsAPI api.RoomserverInternalAPI, + keyAPI keyapi.KeyInternalAPI, + currentStateAPI currentstateapi.CurrentStateInternalAPI, federation *gomatrixserverlib.FederationClient, cfg *config.Dendrite, ) { @@ -58,7 +62,7 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, userAPI) + requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI) roomConsumer := consumers.NewOutputRoomEventConsumer( cfg, consumer, notifier, syncDB, rsAPI, @@ -88,5 +92,13 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( + cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent), + consumer, notifier, keyAPI, currentStateAPI, syncDB, + ) + if err = keyChangeConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start key change consumer") + } + routing.Setup(router, requestPool, syncDB, userAPI, 
federation, rsAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 7bba8e522..f20c73bff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -110,6 +110,10 @@ type StreamingToken struct { logs map[string]*LogPosition } +func (t *StreamingToken) SetLog(name string, lp *LogPosition) { + t.logs[name] = lp +} + func (t *StreamingToken) Log(name string) *LogPosition { l, ok := t.logs[name] if !ok { diff --git a/sytest-whitelist b/sytest-whitelist index 388f95e08..26922df4c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -127,6 +127,7 @@ Can query specific device keys using POST query for user with no keys returns empty key dict Can claim one time key using POST Can claim remote one time key using POST +Local device key changes appear in v2 /sync Can add account data Can add account data to room Can get account data without syncing