From a7e67e65a8662387f1a5ba6860698743f9dbd60f Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 30 Jul 2020 18:00:56 +0100 Subject: [PATCH 1/4] Notify clients when devices are deleted (#1233) * Recheck device lists when join/leave events come in * Add PerformDeviceDeletion * Notify clients when devices are deleted * Unbreak things * Remove debug logging --- build/gobind/monolith.go | 6 +++-- clientapi/routing/device.go | 28 ++++++++++---------- clientapi/routing/routing.go | 4 +-- cmd/dendrite-demo-libp2p/main.go | 6 +++-- cmd/dendrite-demo-yggdrasil/main.go | 6 +++-- cmd/dendrite-key-server/main.go | 3 ++- cmd/dendrite-monolith-server/main.go | 5 ++-- cmd/dendrite-user-api-server/main.go | 2 +- cmd/dendritejs/main.go | 6 +++-- keyserver/api/api.go | 4 +++ keyserver/internal/internal.go | 13 +++++++++- keyserver/inthttp/client.go | 5 ++++ keyserver/keyserver.go | 4 +-- keyserver/producers/keychange.go | 9 ++++--- keyserver/storage/interface.go | 3 ++- syncapi/consumers/roomserver.go | 23 ++++++++++++++++ syncapi/internal/keychange_test.go | 3 +++ syncapi/syncapi.go | 18 ++++++------- sytest-whitelist | 1 + userapi/api/api.go | 10 +++++++ userapi/internal/api.go | 39 ++++++++++++++++++++++++++++ userapi/inthttp/client.go | 13 ++++++++++ userapi/inthttp/server.go | 13 ++++++++++ userapi/userapi.go | 4 ++- userapi/userapi_test.go | 2 +- 25 files changed, 183 insertions(+), 47 deletions(-) diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 5d0d11fda..e2ff79c34 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -118,7 +118,9 @@ func (m *DendriteMonolith) Start() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices, keyAPI) + keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( base, keyRing, federation, @@ -156,7 +158,7 @@ func (m *DendriteMonolith) Start() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI, base.KafkaProducer), + KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 01310400a..11c6c7827 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -165,7 +165,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, userInteractiveAuth *auth.UserInteractive, deviceDB devices.Database, device *api.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { ctx := req.Context() @@ -197,8 +197,12 @@ func DeleteDeviceById( } } - if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil { - util.GetLogger(ctx).WithError(err).Error("deviceDB.RemoveDevice failed") + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{deviceID}, + }, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") return jsonerror.InternalServerError() } @@ -210,26 +214,24 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - ctx := req.Context() payload := devicesDeleteJSON{} if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed") + util.GetLogger(ctx).WithError(err).Error("json.NewDecoder.Decode failed") return jsonerror.InternalServerError() } defer req.Body.Close() // nolint: errcheck - if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed") + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: payload.Devices, + }, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index ebb141ef8..6c40db865 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -654,13 +654,13 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteDeviceById(req, userInteractiveAuth, deviceDB, device, vars["deviceID"]) + return DeleteDeviceById(req, userInteractiveAuth, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return DeleteDevices(req, deviceDB, device) + return DeleteDevices(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index ed7310181..4e774acc9 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -141,7 +141,9 @@ func main() { accountDB := base.Base.CreateAccountsDB() deviceDB := base.Base.CreateDeviceDB() federation := createFederationClient(base) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(base.Base.Cfg, federation, base.Base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) serverKeyAPI := serverkeyapi.NewInternalAPI( base.Base.Cfg, federation, base.Base.Caches, @@ -186,7 +188,7 @@ func main() { ServerKeyAPI: serverKeyAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI, base.Base.KafkaProducer), + KeyAPI: keyAPI, ExtPublicRoomsProvider: provider, } monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5fd761029..0655a2a3b 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -103,7 +103,9 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( base, keyRing, federation, @@ -142,7 +144,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI, base.KafkaProducer), + KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go index d58d475a2..94ea819f8 100644 --- a/cmd/dendrite-key-server/main.go +++ b/cmd/dendrite-key-server/main.go @@ -24,7 +24,8 @@ func main() { base := setup.NewBaseDendrite(cfg, "KeyServer", true) defer base.Close() // nolint: errcheck - intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient(), base.KafkaProducer) + intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.KafkaProducer) + intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index b312579cb..bce5fce08 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -76,7 +76,9 @@ func main() { serverKeyAPI = base.ServerKeyAPIClient() } keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices, keyAPI) + keyAPI.SetUserAPI(userAPI) rsImpl := roomserver.NewInternalAPI( base, keyRing, federation, @@ -119,7 +121,6 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI, base.KafkaProducer) monolith := setup.Monolith{ Config: base.Cfg, diff --git a/cmd/dendrite-user-api-server/main.go b/cmd/dendrite-user-api-server/main.go index 4257da3f3..e6d61da10 100644 --- a/cmd/dendrite-user-api-server/main.go +++ b/cmd/dendrite-user-api-server/main.go @@ -27,7 +27,7 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 70e5279cf..0df53e06d 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -196,7 +196,9 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() federation := createFederationClient(cfg, node) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Matrix.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) fetcher := &libp2pKeyFetcher{} keyRing := gomatrixserverlib.KeyRing{ @@ -233,7 +235,7 @@ func main() { RoomserverAPI: rsAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI, base.KafkaProducer), + KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 98bcd9442..6795498fe 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -19,9 +19,13 @@ import ( "encoding/json" "strings" "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI interface { + // SetUserAPI assigns a user API to query when extracting device names. + SetUserAPI(i userapi.UserInternalAPI) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 703713538..480d1084e 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -40,6 +40,10 @@ type KeyInternalAPI struct { Producer *producers.KeyChange } +func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { + a.UserAPI = i +} + func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { if req.Partition < 0 { req.Partition = a.Producer.DefaultPartition() @@ -272,6 +276,10 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU var keysToStore []api.DeviceKeys // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { + if len(key.KeyJSON) == 0 { + keysToStore = append(keysToStore, key) + continue // deleted keys don't need sanity checking + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -286,6 +294,7 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU ), }) } + // get existing device keys so we can check for changes existingKeys := make([]api.DeviceKeys, len(keysToStore)) for i := range keysToStore { @@ -358,7 +367,9 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) er for _, newKey := range new { exists := false for _, existingKey := range existing { - if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) { + // Do not treat the absence of keys as equal, or else we will not emit key changes + // when users delete devices which never had a key to begin with as both KeyJSONs are nil. + if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 { exists = true break } diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index cd9cf70d4..3f9690b51 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/opentracing/opentracing-go" ) @@ -52,6 +53,10 @@ type httpKeyInternalAPI struct { httpClient *http.Client } +func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { + // no-op: doesn't need it +} + func (h *httpKeyInternalAPI) PerformClaimKeys( ctx context.Context, request *api.PerformClaimKeysRequest, diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index c748d7ce3..36bedf34b 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -37,7 +36,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, producer sarama.SyncProducer, + cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, producer sarama.SyncProducer, ) api.KeyInternalAPI { db, err := storage.NewDatabase( string(cfg.Database.E2EKey), @@ -55,7 +54,6 @@ func NewInternalAPI( DB: db, ThisServer: cfg.Matrix.ServerName, FedClient: fedClient, - UserAPI: userAPI, Producer: keyChangeProducer, } } diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index c51d9f55d..6035b67bd 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -63,10 +63,11 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { return err } logrus.WithFields(logrus.Fields{ - "user_id": key.UserID, - "device_id": key.DeviceID, - "partition": partition, - "offset": offset, + "user_id": key.UserID, + "device_id": key.DeviceID, + "partition": partition, + "offset": offset, + "len_key_bytes": len(key.KeyJSON), }).Infof("Produced to key change topic '%s'", p.Topic) } return nil diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 7a4fce6f5..fade75228 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -32,7 +32,8 @@ type Database interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error - // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. + // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). // Returns an error if there was a problem storing the keys. StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index da4a5366c..f8cdcd5c1 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -35,6 +35,7 @@ type OutputRoomEventConsumer struct { rsConsumer *internal.ContinualConsumer db storage.Database notifier *sync.Notifier + keyChanges *OutputKeyChangeEventConsumer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -44,6 +45,7 @@ func NewOutputRoomEventConsumer( n *sync.Notifier, store storage.Database, rsAPI api.RoomserverInternalAPI, + keyChanges *OutputKeyChangeEventConsumer, ) *OutputRoomEventConsumer { consumer := internal.ContinualConsumer{ @@ -56,6 +58,7 @@ func NewOutputRoomEventConsumer( db: store, notifier: n, rsAPI: rsAPI, + keyChanges: keyChanges, } consumer.ProcessMessage = s.onMessage @@ -160,9 +163,29 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) + s.notifyKeyChanges(&ev) + return nil } +func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.HeaderedEvent) { + if ev.Type() != gomatrixserverlib.MRoomMember || ev.StateKey() == nil { + return + } + membership, err := ev.Membership() + if err != nil { + return + } + switch membership { + case gomatrixserverlib.Join: + s.keyChanges.OnJoinEvent(ev) + case gomatrixserverlib.Ban: + fallthrough + case gomatrixserverlib.Leave: + s.keyChanges.OnLeaveEvent(ev) + } +} + func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 3f18696c4..2c3d154d5 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/currentstateserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -29,6 +30,8 @@ type mockKeyAPI struct{} func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { } +func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} + // PerformClaimKeys claims one-time keys for use in pre-key messages func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 754cd5026..5198d59b1 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -64,8 +64,16 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI) + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( + cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent), + consumer, notifier, keyAPI, currentStateAPI, syncDB, + ) + if err = keyChangeConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start key change consumer") + } + roomConsumer := consumers.NewOutputRoomEventConsumer( - cfg, consumer, notifier, syncDB, rsAPI, + cfg, consumer, notifier, syncDB, rsAPI, keyChangeConsumer, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -92,13 +100,5 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } - keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent), - consumer, notifier, keyAPI, currentStateAPI, syncDB, - ) - if err = keyChangeConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start key change consumer") - } - routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/sytest-whitelist b/sytest-whitelist index 341df8a9a..03baf4d44 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -129,6 +129,7 @@ Can claim one time key using POST Can claim remote one time key using POST Local device key changes appear in v2 /sync Local device key changes appear in /keys/changes +Local delete device changes appear in v2 /sync Get left notifs for other users in sync and /keys/changes when user leaves Can add account data Can add account data to room diff --git a/userapi/api/api.go b/userapi/api/api.go index 5791403ff..5c964c4fd 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -27,6 +27,7 @@ type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error + PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -47,6 +48,15 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +type PerformDeviceDeletionRequest struct { + UserID string + // The devices to delete + DeviceIDs []string +} + +type PerformDeviceDeletionResponse struct { +} + // QueryDeviceInfosRequest is the request to QueryDeviceInfos type QueryDeviceInfosRequest struct { DeviceIDs []string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5b1541967..738023dd6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -25,10 +25,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type UserInternalAPI struct { @@ -37,6 +39,7 @@ type UserInternalAPI struct { ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService + KeyAPI keyapi.KeyInternalAPI } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -104,6 +107,42 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe return nil } +func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { + util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion") + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + } + err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + if err != nil { + return err + } + // create empty device keys and upload them to delete what was once there and trigger device list changes + deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs)) + for i, did := range req.DeviceIDs { + deviceKeys[i] = keyapi.DeviceKeys{ + UserID: req.UserID, + DeviceID: did, + KeyJSON: nil, + } + } + + var uploadRes keyapi.PerformUploadKeysResponse + a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + DeviceKeys: deviceKeys, + }, &uploadRes) + if uploadRes.Error != nil { + return fmt.Errorf("Failed to delete device keys: %v", uploadRes.Error) + } + if len(uploadRes.KeyErrors) > 0 { + return fmt.Errorf("Failed to delete device keys, key errors: %+v", uploadRes.KeyErrors) + } + return nil +} + func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 3e1ac0662..47e2110f9 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -91,6 +92,18 @@ func (h *httpUserInternalAPI) PerformDeviceCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformDeviceDeletion( + ctx context.Context, + request *api.PerformDeviceDeletionRequest, + response *api.PerformDeviceDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d29f4d442..ebb9bf4e8 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeviceDeletionPath, + httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceDeletionRequest{} + response := api.PerformDeviceDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryProfilePath, httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { request := api.QueryProfileRequest{} diff --git a/userapi/userapi.go b/userapi/userapi.go index 7aadec06a..c4ab90bac 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,6 +17,7 @@ package userapi import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/config" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" @@ -34,12 +35,13 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, - serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI { + serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, ServerName: serverName, AppServices: appServices, + KeyAPI: keyAPI, } } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 163b10ec7..dab1ec716 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -32,7 +32,7 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, t.Fatalf("failed to create device DB: %s", err) } - return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB + return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB } func TestQueryProfile(t *testing.T) { From b5cb1d153458ad83abdfbebed7405dd9da159cb8 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Fri, 31 Jul 2020 14:40:45 +0100 Subject: [PATCH 2/4] Fix edge cases around device lists (#1234) * Fix New users appear in /keys/changes * Create blank device keys when logging in on a new device * Add PerformDeviceUpdate and fix a few bugs - Correct device deletion query on sqlite - Return no keys on /keys/query rather than an empty key * Unbreak sqlite properly * Use a real DB for currentstateserver integration tests * Race fix --- clientapi/routing/device.go | 48 ++++++++----------- clientapi/routing/login.go | 24 ++++++---- clientapi/routing/routing.go | 4 +- currentstateserver/currentstateserver_test.go | 23 ++++++--- keyserver/internal/internal.go | 3 ++ syncapi/internal/keychange.go | 33 +++++++++++++ syncapi/sync/requestpool.go | 2 +- sytest-whitelist | 4 ++ userapi/api/api.go | 11 +++++ userapi/internal/api.go | 42 ++++++++++++++-- userapi/inthttp/client.go | 9 ++++ userapi/inthttp/server.go | 13 +++++ .../storage/devices/sqlite3/devices_table.go | 3 +- 13 files changed, 167 insertions(+), 52 deletions(-) diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 11c6c7827..d0b3bdbe5 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -115,33 +115,9 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), - } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() - } - - if dev.UserID != device.UserID { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), - } - } defer req.Body.Close() // nolint: errcheck @@ -152,10 +128,28 @@ func UpdateDeviceByID( return jsonerror.InternalServerError() } - if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed") + var performRes api.PerformDeviceUpdateResponse + err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{ + RequestingUserID: device.UserID, + DeviceID: deviceID, + DisplayName: payload.DisplayName, + }, &performRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") return jsonerror.InternalServerError() } + if !performRes.DeviceExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("device does not exist"), + } + } + if performRes.Forbidden { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 7f47aafff..42f828f64 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -57,7 +57,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, deviceDB devices.Database, + req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -81,7 +81,7 @@ func Login( return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login) + return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login) } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, @@ -90,7 +90,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ) util.JSONResponse { token, err := auth.GenerateAccessToken() if err != nil { @@ -104,9 +104,13 @@ func completeAuth( return jsonerror.InternalServerError() } - dev, err := deviceDB.CreateDevice( - ctx, localpart, login.DeviceID, token, login.InitialDisplayName, - ) + var performRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + DeviceDisplayName: login.InitialDisplayName, + DeviceID: login.DeviceID, + AccessToken: token, + Localpart: localpart, + }, &performRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -117,10 +121,10 @@ func completeAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: loginResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, + UserID: performRes.Device.UserID, + AccessToken: performRes.Device.AccessToken, HomeServer: serverName, - DeviceID: dev.ID, + DeviceID: performRes.Device.ID, }, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 6c40db865..0e58129ef 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -387,7 +387,7 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return Login(req, accountDB, deviceDB, cfg) + return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -644,7 +644,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) + return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodPut, http.MethodOptions) diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index 1366a0be8..193173906 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -19,6 +19,7 @@ import ( "crypto/ed25519" "encoding/json" "net/http" + "os" "reflect" "testing" "time" @@ -91,11 +92,13 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms return nil } -func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) { +func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) { cfg := &config.Dendrite{} + stateDBName := "test_state.db" + naffkaDBName := "test_naffka.db" cfg.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic) - cfg.Database.CurrentState = config.DataSource("file::memory:") - db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file::memory:", nil) + cfg.Database.CurrentState = config.DataSource("file:" + stateDBName) + db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil) if err != nil { t.Fatalf("Failed to open naffka database: %s", err) } @@ -107,11 +110,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync if err != nil { t.Fatalf("Failed to create naffka consumer: %s", err) } - return NewInternalAPI(cfg, naff), naff + return NewInternalAPI(cfg, naff), naff, func() { + os.Remove(naffkaDBName) + os.Remove(stateDBName) + } } func TestQueryCurrentState(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() plTuple := gomatrixserverlib.StateKeyTuple{ EventType: "m.room.power_levels", StateKey: "", @@ -209,7 +216,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r // This test makes sure that QuerySharedUsers is returning the correct users for a range of sets. func TestQuerySharedUsers(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join")) MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join")) @@ -222,6 +230,9 @@ func TestQuerySharedUsers(t *testing.T) { MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join")) + // we don't know when the server has processed the events + time.Sleep(10 * time.Millisecond) + testCases := []struct { req api.QuerySharedUsersRequest wantRes api.QuerySharedUsersResponse diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 480d1084e..bb8286635 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -206,6 +206,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } // inject display name if known dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index cb4fca7d8..0272ffc06 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -16,6 +16,7 @@ package internal import ( "context" + "strings" "github.com/Shopify/sarama" currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" @@ -88,6 +89,16 @@ func DeviceListCatchup( if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) hasNew = true + userSet[userID] = true + } + } + // if the response has any join/leave events, add them now. + // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. + for _, userID := range membershipEvents(res) { + if !userSet[userID] { + res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) + hasNew = true + userSet[userID] = true } } return hasNew, nil @@ -219,3 +230,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin } return false } + +// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to +// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84 +// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device, +// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler +// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room." +func membershipEvents(res *types.Response) (userIDs []string) { + for _, room := range res.Rooms.Join { + for _, ev := range room.Timeline.Events { + if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { + if strings.Contains(string(ev.Content), `"join"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"leave"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"ban"`) { + userIDs = append(userIDs, *ev.StateKey) + } + } + } + } + return +} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index f817f0981..b530b34d1 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } // work out room joins/leaves res, err := rp.db.IncrementalSync( - req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false, + req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync") diff --git a/sytest-whitelist b/sytest-whitelist index 03baf4d44..16a71c648 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -129,7 +129,11 @@ Can claim one time key using POST Can claim remote one time key using POST Local device key changes appear in v2 /sync Local device key changes appear in /keys/changes +New users appear in /keys/changes Local delete device changes appear in v2 /sync +Local new device changes appear in v2 /sync +Local update device changes appear in v2 /sync +Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Can add account data Can add account data to room diff --git a/userapi/api/api.go b/userapi/api/api.go index 5c964c4fd..84338dbf2 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -28,6 +28,7 @@ type UserInternalAPI interface { PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -48,6 +49,16 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +type PerformDeviceUpdateRequest struct { + RequestingUserID string + DeviceID string + DisplayName *string +} +type PerformDeviceUpdateResponse struct { + DeviceExists bool + Forbidden bool +} + type PerformDeviceDeletionRequest struct { UserID string // The devices to delete diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 738023dd6..b9d188229 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -104,7 +104,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - return nil + // create empty device keys and upload them to trigger device list changes + return a.deviceListUpdate(dev.UserID, []string{dev.ID}) } func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { @@ -121,10 +122,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs)) - for i, did := range req.DeviceIDs { + return a.deviceListUpdate(req.UserID, req.DeviceIDs) +} + +func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { + deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + for i, did := range deviceIDs { deviceKeys[i] = keyapi.DeviceKeys{ - UserID: req.UserID, + UserID: userID, DeviceID: did, KeyJSON: nil, } @@ -143,6 +148,35 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return nil } +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } + return nil +} + func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 47e2110f9..5f4df0eb1 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -31,6 +31,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -104,6 +105,14 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ebb9bf4e8..47d68ff21 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeviceUpdatePath, + httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceUpdateRequest{} + response := api.PerformDeviceUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceDeletionPath, httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceDeletionRequest{} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index efe6f927c..9b535aab9 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) prep, err := s.db.Prepare(orig) if err != nil { return err @@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices( for i, v := range devices { params[i+1] = v } - params = append(params, params...) _, err = stmt.ExecContext(ctx, params...) return err }) From ffcb6d2ea199cfa985e72ffbdcb884fb9bc9f54d Mon Sep 17 00:00:00 2001 From: Kegsay Date: Mon, 3 Aug 2020 12:29:58 +0100 Subject: [PATCH 3/4] Produce OTK counts in /sync response (#1235) * Add QueryOneTimeKeys for /sync extensions * Unbreak tests * Produce OTK counts in /sync response * Linting --- keyserver/api/api.go | 14 ++++++++++++ keyserver/internal/internal.go | 11 ++++++++++ keyserver/inthttp/client.go | 18 +++++++++++++++ keyserver/inthttp/server.go | 11 ++++++++++ keyserver/storage/interface.go | 3 +++ .../storage/postgres/one_time_keys_table.go | 22 +++++++++++++++++++ keyserver/storage/shared/storage.go | 4 ++++ .../storage/sqlite3/one_time_keys_table.go | 22 +++++++++++++++++++ keyserver/storage/tables/interface.go | 1 + syncapi/internal/keychange.go | 15 +++++++++++++ syncapi/internal/keychange_test.go | 3 +++ syncapi/sync/requestpool.go | 19 ++++++++++------ syncapi/types/types.go | 2 ++ 13 files changed, 138 insertions(+), 7 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 6795498fe..eb2f9e24a 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -31,6 +31,7 @@ type KeyInternalAPI interface { PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -157,3 +158,16 @@ type QueryKeyChangesResponse struct { // Set if there was a problem handling the request. Error *KeyError } + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index bb8286635..3c8dff847 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -168,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } +func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { + count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + return + } + res.Count = *count +} + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 3f9690b51..b65cbdafb 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -31,6 +31,7 @@ const ( PerformClaimKeysPath = "/keyserver/performClaimKeys" QueryKeysPath = "/keyserver/queryKeys" QueryKeyChangesPath = "/keyserver/queryKeyChanges" + QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -108,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys( } } +func (h *httpKeyInternalAPI) QueryOneTimeKeys( + ctx context.Context, + request *api.QueryOneTimeKeysRequest, + response *api.QueryOneTimeKeysResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") + defer span.Finish() + + apiURL := h.apiURL + QueryOneTimeKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} + func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index f3d2882c2..615b6f80e 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOneTimeKeysPath, + httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { + request := api.QueryOneTimeKeysRequest{} + response := api.QueryOneTimeKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryOneTimeKeys(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryKeyChangesPath, httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { request := api.QueryKeyChangesRequest{} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index fade75228..0e0158e58 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -29,6 +29,9 @@ type Database interface { // StoreOneTimeKeys persists the given one-time keys. StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index a9d05548b..df215d5a8 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 8c2534f5c..44cb0cc25 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -39,6 +39,10 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) ( return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) } +func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index fecf533e6..b35407cd4 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 8b89283f5..c6e43be45 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -24,6 +24,7 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 0272ffc06..66134d791 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -29,6 +29,20 @@ import ( const DeviceListLogName = "dl" +// DeviceOTKCounts adds one-time key counts to the /sync response +func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ + UserID: userID, + DeviceID: deviceID, + }, &queryRes) + if queryRes.Error != nil { + return queryRes.Error + } + res.DeviceListsOTKCount = queryRes.Count.KeyCount + return nil +} + // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. @@ -36,6 +50,7 @@ func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, ) (hasNew bool, err error) { + // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. newlyJoinedRooms := joinedRooms(res, userID) newlyLeftRooms := leftRooms(res) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 2c3d154d5..af456af41 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -38,6 +38,9 @@ func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformCl func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { } func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +} +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { + } type mockCurrentStateAPI struct { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b530b34d1..12c597bbe 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { - res = types.NewResponse() +// nolint:gocyclo +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) { + res := types.NewResponse() since := types.NewStreamToken(0, 0, nil) if req.since != nil { @@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) } if err != nil { - return + return res, err } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) if err != nil { - return + return res, err } res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos) if err != nil { - return + return res, err + } + err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res) + if err != nil { + return res, err } // Before we return the sync response, make sure that we take action on @@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea // Handle the updates and deletions in the database. err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) if err != nil { - return + return res, err } } if len(events) > 0 { @@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea } } - return + return res, err } func (rp *RequestPool) appendDeviceLists( diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 4761cce28..f465d9fff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -393,6 +393,7 @@ type Response struct { Changed []string `json:"changed,omitempty"` Left []string `json:"left,omitempty"` } `json:"device_lists,omitempty"` + DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"` } // NewResponse creates an empty response with initialised maps. @@ -411,6 +412,7 @@ func NewResponse() *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) + res.DeviceListsOTKCount = make(map[string]int) return &res } From fb56bbf0b7d4b21da3f55b066e71d24bf4599887 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Mon, 3 Aug 2020 17:07:06 +0100 Subject: [PATCH 4/4] Generate stream IDs for locally uploaded device keys (#1236) * Breaking: add stream_id to keyserver_device_keys table * Add tests for stream ID generation * Fix whitelist --- keyserver/api/api.go | 17 ++++ keyserver/internal/internal.go | 36 +++++--- keyserver/producers/keychange.go | 2 +- keyserver/storage/interface.go | 9 +- .../storage/postgres/device_keys_table.go | 84 ++++++++++++------- keyserver/storage/shared/storage.go | 29 ++++++- .../storage/sqlite3/device_keys_table.go | 80 +++++++++++------- keyserver/storage/storage_test.go | 82 ++++++++++++++++++ keyserver/storage/tables/interface.go | 7 +- syncapi/consumers/keychange.go | 2 +- sytest-whitelist | 1 + 11 files changed, 265 insertions(+), 84 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index eb2f9e24a..080d0e5fd 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -43,6 +43,13 @@ func (k *KeyError) Error() string { return k.Err } +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + DeviceKeys + // A monotonically increasing number which represents device changes for this user. + StreamID int +} + // DeviceKeys represents a set of device keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type DeviceKeys struct { @@ -50,10 +57,20 @@ type DeviceKeys struct { UserID string // The device ID of this device DeviceID string + // The device display name + DisplayName string // The raw device key JSON KeyJSON []byte } +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { + return DeviceMessage{ + DeviceKeys: *k, + StreamID: streamID, + } +} + // OneTimeKeys represents a set of one-time keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type OneTimeKeys struct { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 3c8dff847..9027cbf4f 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) - a.uploadDeviceKeys(ctx, req, res) + a.uploadLocalDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } @@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - var keysToStore []api.DeviceKeys +func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if serverName != a.ThisServer { + continue // ignore remote users + } if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue } @@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU } // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceKeys, len(keysToStore)) + existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { - existingKeys[i] = api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, + existingKeys[i] = api.DeviceMessage{ + DeviceKeys: api.DeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, + }, } } if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { @@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU return } // store the device keys and emit changes - if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { + err := a.DB.StoreDeviceKeys(ctx, keysToStore) + if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err := a.emitDeviceKeyChanges(existingKeys, keysToStore) + err = a.emitDeviceKeyChanges(existingKeys, keysToStore) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { +func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error { // find keys in new that are not in existing - var keysAdded []api.DeviceKeys + var keysAdded []api.DeviceMessage for _, newKey := range new { exists := false for _, existingKey := range existing { diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 6035b67bd..99629b42e 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 { } // ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { for _, key := range keys { var m sarama.ProducerMessage diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0e0158e58..11284d86b 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -32,17 +32,18 @@ type Database interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d915246c7..e1b4e9475 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err @@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 44cb0cc25..e78ee9433 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) +func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = int(streamID) + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 69fe7a6e4..9f70885ad 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true @@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return result, rows.Err() } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 66f6930f9..b3e45e6cc 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/keyserver/api" ) var ctx = context.Background() @@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } } + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index c6e43be45..65da3310c 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -32,9 +32,10 @@ type OneTimeKeys interface { } type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error - InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } type KeyChanges interface { diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 35978be71..e14d2223e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er defer func() { s.updateOffset(msg) }() - var output api.DeviceKeys + var output api.DeviceMessage if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server") diff --git a/sytest-whitelist b/sytest-whitelist index 16a71c648..a1d2e437c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync Sync can be polled for updates Sync is woken up for leaves Newly left rooms appear in the leave section of incremental sync +Rooms can be created with an initial invite list (SYN-205) We should see our own leave event, even if history_visibility is restricted (SYN-662) We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462) Newly left rooms appear in the leave section of gapped sync