diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index d9faeb2b7..598f5b085 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -118,7 +118,9 @@ func (m *DendriteMonolith) Start() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( base, keyRing, federation, @@ -156,7 +158,11 @@ func (m *DendriteMonolith) Start() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, +<<<<<<< HEAD KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer), +======= + KeyAPI: keyAPI, +>>>>>>> master ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 01310400a..d0b3bdbe5 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -115,33 +115,9 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), - } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() - } - - if dev.UserID != device.UserID { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), - } - } defer req.Body.Close() // nolint: errcheck @@ -152,10 +128,28 @@ func UpdateDeviceByID( return jsonerror.InternalServerError() } - if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed") + var performRes api.PerformDeviceUpdateResponse + err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{ + RequestingUserID: device.UserID, + DeviceID: deviceID, + DisplayName: payload.DisplayName, + }, &performRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") return jsonerror.InternalServerError() } + if !performRes.DeviceExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("device does not exist"), + } + } + if performRes.Forbidden { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } return util.JSONResponse{ Code: http.StatusOK, @@ -165,7 +159,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, userInteractiveAuth *auth.UserInteractive, deviceDB devices.Database, device *api.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { ctx := req.Context() @@ -197,8 +191,12 @@ func DeleteDeviceById( } } - if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil { - util.GetLogger(ctx).WithError(err).Error("deviceDB.RemoveDevice failed") + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{deviceID}, + }, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") return jsonerror.InternalServerError() } @@ -210,26 +208,24 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - ctx := req.Context() payload := devicesDeleteJSON{} if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed") + util.GetLogger(ctx).WithError(err).Error("json.NewDecoder.Decode failed") return jsonerror.InternalServerError() } defer req.Body.Close() // nolint: errcheck - if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed") + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: payload.Devices, + }, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index dae593bf0..d2bc9337d 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -57,7 +57,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, deviceDB devices.Database, + req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -81,7 +81,7 @@ func Login( return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login) + return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login) } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, @@ -90,7 +90,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ) util.JSONResponse { token, err := auth.GenerateAccessToken() if err != nil { @@ -104,9 +104,13 @@ func completeAuth( return jsonerror.InternalServerError() } - dev, err := deviceDB.CreateDevice( - ctx, localpart, login.DeviceID, token, login.InitialDisplayName, - ) + var performRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + DeviceDisplayName: login.InitialDisplayName, + DeviceID: login.DeviceID, + AccessToken: token, + Localpart: localpart, + }, &performRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -117,10 +121,10 @@ func completeAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: loginResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, + UserID: performRes.Device.UserID, + AccessToken: performRes.Device.AccessToken, HomeServer: serverName, - DeviceID: dev.ID, + DeviceID: performRes.Device.ID, }, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 57aefd0ad..883b473bf 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -387,7 +387,7 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return Login(req, accountDB, deviceDB, cfg) + return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -644,7 +644,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) + return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -654,13 +654,13 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteDeviceById(req, userInteractiveAuth, deviceDB, device, vars["deviceID"]) + return DeleteDeviceById(req, userInteractiveAuth, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodDelete, http.MethodOptions) r0mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return DeleteDevices(req, deviceDB, device) + return DeleteDevices(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 31e2f87fa..0f5196299 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -144,7 +144,9 @@ func main() { accountDB := base.Base.CreateAccountsDB() deviceDB := base.Base.CreateDeviceDB() federation := createFederationClient(base) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) serverKeyAPI := serverkeyapi.NewInternalAPI( &base.Base.Cfg.ServerKeyAPI, federation, base.Base.Caches, @@ -189,7 +191,7 @@ func main() { ServerKeyAPI: serverKeyAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, userAPI, base.Base.KafkaProducer), + KeyAPI: keyAPI, ExtPublicRoomsProvider: provider, } monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5a01f5f77..0f8a6029f 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -105,7 +105,9 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( base, keyRing, federation, @@ -144,8 +146,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, - KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer), - //ServerKeyAPI: serverKeyAPI, + KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go index 669808b2a..f3110a1e1 100644 --- a/cmd/dendrite-key-server/main.go +++ b/cmd/dendrite-key-server/main.go @@ -24,7 +24,8 @@ func main() { base := setup.NewBaseDendrite(cfg, "KeyServer", true) defer base.Close() // nolint: errcheck - intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.UserAPIClient(), base.KafkaProducer) + intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.KafkaProducer) + intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 289aa7754..a6d25b253 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -78,7 +78,9 @@ func main() { serverKeyAPI = base.ServerKeyAPIClient() } keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices) + keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + keyAPI.SetUserAPI(userAPI) rsImpl := roomserver.NewInternalAPI( base, keyRing, federation, @@ -121,7 +123,6 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) - keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer) monolith := setup.Monolith{ Config: base.Cfg, diff --git a/cmd/dendrite-user-api-server/main.go b/cmd/dendrite-user-api-server/main.go index 4655cd09a..22b6255eb 100644 --- a/cmd/dendrite-user-api-server/main.go +++ b/cmd/dendrite-user-api-server/main.go @@ -27,7 +27,7 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index c9f2f181e..a815c4057 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -196,7 +196,9 @@ func main() { accountDB := base.CreateAccountsDB() deviceDB := base.CreateDeviceDB() federation := createFederationClient(cfg, node) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil) + keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) + userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + keyAPI.SetUserAPI(userAPI) fetcher := &libp2pKeyFetcher{} keyRing := gomatrixserverlib.KeyRing{ @@ -233,7 +235,7 @@ func main() { RoomserverAPI: rsAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer), + KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index 42751b2e3..eb189275e 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -20,6 +20,7 @@ import ( "crypto/ed25519" "encoding/json" "net/http" + "os" "reflect" "testing" "time" @@ -92,13 +93,14 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms return nil } -func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) { +func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) { cfg := &config.Dendrite{} - cfg.Defaults() + stateDBName := "test_state.db" + naffkaDBName := "test_naffka.db" cfg.Global.ServerName = "kaer.morhen" cfg.Global.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic) - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file::memory:") - db, err := sqlutil.Open(&cfg.CurrentStateServer.Database) + cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file:" + stateDBName) + db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil) if err != nil { t.Fatalf("Failed to open naffka database: %s", err) } @@ -110,11 +112,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync if err != nil { t.Fatalf("Failed to create naffka consumer: %s", err) } - return NewInternalAPI(&cfg.CurrentStateServer, naff), naff + return NewInternalAPI(cfg, naff), naff, func() { + os.Remove(naffkaDBName) + os.Remove(stateDBName) + } } func TestQueryCurrentState(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() plTuple := gomatrixserverlib.StateKeyTuple{ EventType: "m.room.power_levels", StateKey: "", @@ -217,7 +223,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r // This test makes sure that QuerySharedUsers is returning the correct users for a range of sets. func TestQuerySharedUsers(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join")) MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join")) @@ -230,6 +237,9 @@ func TestQuerySharedUsers(t *testing.T) { MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join")) + // we don't know when the server has processed the events + time.Sleep(10 * time.Millisecond) + testCases := []struct { req api.QuerySharedUsersRequest wantRes api.QuerySharedUsersResponse diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 98bcd9442..080d0e5fd 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -19,14 +19,19 @@ import ( "encoding/json" "strings" "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI interface { + // SetUserAPI assigns a user API to query when extracting device names. + SetUserAPI(i userapi.UserInternalAPI) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -38,6 +43,13 @@ func (k *KeyError) Error() string { return k.Err } +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + DeviceKeys + // A monotonically increasing number which represents device changes for this user. + StreamID int +} + // DeviceKeys represents a set of device keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type DeviceKeys struct { @@ -45,10 +57,20 @@ type DeviceKeys struct { UserID string // The device ID of this device DeviceID string + // The device display name + DisplayName string // The raw device key JSON KeyJSON []byte } +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { + return DeviceMessage{ + DeviceKeys: *k, + StreamID: streamID, + } +} + // OneTimeKeys represents a set of one-time keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type OneTimeKeys struct { @@ -153,3 +175,16 @@ type QueryKeyChangesResponse struct { // Set if there was a problem handling the request. Error *KeyError } + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 703713538..9027cbf4f 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -40,6 +40,10 @@ type KeyInternalAPI struct { Producer *producers.KeyChange } +func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { + a.UserAPI = i +} + func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { if req.Partition < 0 { req.Partition = a.Producer.DefaultPartition() @@ -57,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) - a.uploadDeviceKeys(ctx, req, res) + a.uploadLocalDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } @@ -164,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } +func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { + count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + return + } + res.Count = *count +} + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) @@ -202,6 +217,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } // inject display name if known dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` @@ -268,14 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - var keysToStore []api.DeviceKeys +func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if serverName != a.ThisServer { + continue // ignore remote users + } + if len(key.KeyJSON) == 0 { + keysToStore = append(keysToStore, key.WithStreamID(0)) + continue // deleted keys don't need sanity checking + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue } @@ -286,12 +315,15 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU ), }) } + // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceKeys, len(keysToStore)) + existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { - existingKeys[i] = api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, + existingKeys[i] = api.DeviceMessage{ + DeviceKeys: api.DeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, + }, } } if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { @@ -301,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU return } // store the device keys and emit changes - if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { + err := a.DB.StoreDeviceKeys(ctx, keysToStore) + if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err := a.emitDeviceKeyChanges(existingKeys, keysToStore) + err = a.emitDeviceKeyChanges(existingKeys, keysToStore) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -352,13 +385,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { +func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error { // find keys in new that are not in existing - var keysAdded []api.DeviceKeys + var keysAdded []api.DeviceMessage for _, newKey := range new { exists := false for _, existingKey := range existing { - if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) { + // Do not treat the absence of keys as equal, or else we will not emit key changes + // when users delete devices which never had a key to begin with as both KeyJSONs are nil. + if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 { exists = true break } diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index cd9cf70d4..b65cbdafb 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/opentracing/opentracing-go" ) @@ -30,6 +31,7 @@ const ( PerformClaimKeysPath = "/keyserver/performClaimKeys" QueryKeysPath = "/keyserver/queryKeys" QueryKeyChangesPath = "/keyserver/queryKeyChanges" + QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -52,6 +54,10 @@ type httpKeyInternalAPI struct { httpClient *http.Client } +func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { + // no-op: doesn't need it +} + func (h *httpKeyInternalAPI) PerformClaimKeys( ctx context.Context, request *api.PerformClaimKeysRequest, @@ -103,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys( } } +func (h *httpKeyInternalAPI) QueryOneTimeKeys( + ctx context.Context, + request *api.QueryOneTimeKeysRequest, + response *api.QueryOneTimeKeysResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") + defer span.Finish() + + apiURL := h.apiURL + QueryOneTimeKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} + func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index f3d2882c2..615b6f80e 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOneTimeKeysPath, + httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { + request := api.QueryOneTimeKeysRequest{} + response := api.QueryOneTimeKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryOneTimeKeys(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryKeyChangesPath, httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { request := api.QueryKeyChangesRequest{} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index e6a8f1667..ada49a215 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -37,7 +36,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, producer sarama.SyncProducer, + cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, producer sarama.SyncProducer, ) api.KeyInternalAPI { db, err := storage.NewDatabase(&cfg.Database) if err != nil { @@ -52,7 +51,6 @@ func NewInternalAPI( DB: db, ThisServer: cfg.Matrix.ServerName, FedClient: fedClient, - UserAPI: userAPI, Producer: keyChangeProducer, } } diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index c51d9f55d..99629b42e 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 { } // ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { for _, key := range keys { var m sarama.ProducerMessage @@ -63,10 +63,11 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { return err } logrus.WithFields(logrus.Fields{ - "user_id": key.UserID, - "device_id": key.DeviceID, - "partition": partition, - "offset": offset, + "user_id": key.UserID, + "device_id": key.DeviceID, + "partition": partition, + "offset": offset, + "len_key_bytes": len(key.KeyJSON), }).Infof("Produced to key change topic '%s'", p.Topic) } return nil diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 7a4fce6f5..11284d86b 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -29,16 +29,21 @@ type Database interface { // StoreOneTimeKeys persists the given one-time keys. StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d915246c7..e1b4e9475 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err @@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index a9d05548b..df215d5a8 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 8c2534f5c..e78ee9433 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -39,15 +39,40 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) ( return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) } -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) +func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = int(streamID) + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 69fe7a6e4..9f70885ad 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true @@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return result, rows.Err() } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } - } - return nil - }) +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } + } + return nil } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index fecf533e6..b35407cd4 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index ad67a3b82..1c76f1bf7 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -7,6 +7,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/keyserver/api" ) var ctx = context.Background() @@ -82,3 +83,84 @@ func TestKeyChangesUpperLimit(t *testing.T) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } } + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 8b89283f5..65da3310c 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -24,6 +24,7 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. @@ -31,9 +32,10 @@ type OneTimeKeys interface { } type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error - InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } type KeyChanges interface { diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 35978be71..e14d2223e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er defer func() { s.updateOffset(msg) }() - var output api.DeviceKeys + var output api.DeviceMessage if err := json.Unmarshal(msg.Value, &output); err != nil { // If the message was invalid, log it and move on to the next message in the stream log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server") diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index ebceb7370..06c904c39 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -35,6 +35,7 @@ type OutputRoomEventConsumer struct { rsConsumer *internal.ContinualConsumer db storage.Database notifier *sync.Notifier + keyChanges *OutputKeyChangeEventConsumer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -44,6 +45,7 @@ func NewOutputRoomEventConsumer( n *sync.Notifier, store storage.Database, rsAPI api.RoomserverInternalAPI, + keyChanges *OutputKeyChangeEventConsumer, ) *OutputRoomEventConsumer { consumer := internal.ContinualConsumer{ @@ -56,6 +58,7 @@ func NewOutputRoomEventConsumer( db: store, notifier: n, rsAPI: rsAPI, + keyChanges: keyChanges, } consumer.ProcessMessage = s.onMessage @@ -160,9 +163,29 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) + s.notifyKeyChanges(&ev) + return nil } +func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.HeaderedEvent) { + if ev.Type() != gomatrixserverlib.MRoomMember || ev.StateKey() == nil { + return + } + membership, err := ev.Membership() + if err != nil { + return + } + switch membership { + case gomatrixserverlib.Join: + s.keyChanges.OnJoinEvent(ev) + case gomatrixserverlib.Ban: + fallthrough + case gomatrixserverlib.Leave: + s.keyChanges.OnLeaveEvent(ev) + } +} + func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index cb4fca7d8..66134d791 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -16,6 +16,7 @@ package internal import ( "context" + "strings" "github.com/Shopify/sarama" currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" @@ -28,6 +29,20 @@ import ( const DeviceListLogName = "dl" +// DeviceOTKCounts adds one-time key counts to the /sync response +func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ + UserID: userID, + DeviceID: deviceID, + }, &queryRes) + if queryRes.Error != nil { + return queryRes.Error + } + res.DeviceListsOTKCount = queryRes.Count.KeyCount + return nil +} + // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. @@ -35,6 +50,7 @@ func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, ) (hasNew bool, err error) { + // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. newlyJoinedRooms := joinedRooms(res, userID) newlyLeftRooms := leftRooms(res) @@ -88,6 +104,16 @@ func DeviceListCatchup( if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) hasNew = true + userSet[userID] = true + } + } + // if the response has any join/leave events, add them now. + // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. + for _, userID := range membershipEvents(res) { + if !userSet[userID] { + res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) + hasNew = true + userSet[userID] = true } } return hasNew, nil @@ -219,3 +245,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin } return false } + +// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to +// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84 +// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device, +// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler +// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room." +func membershipEvents(res *types.Response) (userIDs []string) { + for _, room := range res.Rooms.Join { + for _, ev := range room.Timeline.Events { + if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { + if strings.Contains(string(ev.Content), `"join"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"leave"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"ban"`) { + userIDs = append(userIDs, *ev.StateKey) + } + } + } + } + return +} diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 3f18696c4..af456af41 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/currentstateserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -29,12 +30,17 @@ type mockKeyAPI struct{} func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { } +func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} + // PerformClaimKeys claims one-time keys for use in pre-key messages func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { } func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { } func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +} +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { + } type mockCurrentStateAPI struct { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index f817f0981..12c597bbe 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } // work out room joins/leaves res, err := rp.db.IncrementalSync( - req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false, + req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync") @@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { - res = types.NewResponse() +// nolint:gocyclo +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) { + res := types.NewResponse() since := types.NewStreamToken(0, 0, nil) if req.since != nil { @@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) } if err != nil { - return + return res, err } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) if err != nil { - return + return res, err } res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos) if err != nil { - return + return res, err + } + err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res) + if err != nil { + return res, err } // Before we return the sync response, make sure that we take action on @@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea // Handle the updates and deletions in the database. err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) if err != nil { - return + return res, err } } if len(events) > 0 { @@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea } } - return + return res, err } func (rp *RequestPool) appendDeviceLists( diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 606c8c281..9caed7be0 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -64,8 +64,16 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI) + keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( + cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent), + consumer, notifier, keyAPI, currentStateAPI, syncDB, + ) + if err = keyChangeConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start key change consumer") + } + roomConsumer := consumers.NewOutputRoomEventConsumer( - cfg, consumer, notifier, syncDB, rsAPI, + cfg, consumer, notifier, syncDB, rsAPI, keyChangeConsumer, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -92,13 +100,5 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } - keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent), - consumer, notifier, keyAPI, currentStateAPI, syncDB, - ) - if err = keyChangeConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start key change consumer") - } - routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 4761cce28..f465d9fff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -393,6 +393,7 @@ type Response struct { Changed []string `json:"changed,omitempty"` Left []string `json:"left,omitempty"` } `json:"device_lists,omitempty"` + DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"` } // NewResponse creates an empty response with initialised maps. @@ -411,6 +412,7 @@ func NewResponse() *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) + res.DeviceListsOTKCount = make(map[string]int) return &res } diff --git a/sytest-whitelist b/sytest-whitelist index 341df8a9a..a1d2e437c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync Sync can be polled for updates Sync is woken up for leaves Newly left rooms appear in the leave section of incremental sync +Rooms can be created with an initial invite list (SYN-205) We should see our own leave event, even if history_visibility is restricted (SYN-662) We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462) Newly left rooms appear in the leave section of gapped sync @@ -129,6 +130,11 @@ Can claim one time key using POST Can claim remote one time key using POST Local device key changes appear in v2 /sync Local device key changes appear in /keys/changes +New users appear in /keys/changes +Local delete device changes appear in v2 /sync +Local new device changes appear in v2 /sync +Local update device changes appear in v2 /sync +Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Can add account data Can add account data to room diff --git a/userapi/api/api.go b/userapi/api/api.go index 5791403ff..84338dbf2 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -27,6 +27,8 @@ type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error + PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -47,6 +49,25 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +type PerformDeviceUpdateRequest struct { + RequestingUserID string + DeviceID string + DisplayName *string +} +type PerformDeviceUpdateResponse struct { + DeviceExists bool + Forbidden bool +} + +type PerformDeviceDeletionRequest struct { + UserID string + // The devices to delete + DeviceIDs []string +} + +type PerformDeviceDeletionResponse struct { +} + // QueryDeviceInfosRequest is the request to QueryDeviceInfos type QueryDeviceInfosRequest struct { DeviceIDs []string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5b1541967..b9d188229 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -25,10 +25,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type UserInternalAPI struct { @@ -37,6 +39,7 @@ type UserInternalAPI struct { ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService + KeyAPI keyapi.KeyInternalAPI } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -101,6 +104,76 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev + // create empty device keys and upload them to trigger device list changes + return a.deviceListUpdate(dev.UserID, []string{dev.ID}) +} + +func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { + util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion") + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + } + err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + if err != nil { + return err + } + // create empty device keys and upload them to delete what was once there and trigger device list changes + return a.deviceListUpdate(req.UserID, req.DeviceIDs) +} + +func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { + deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + for i, did := range deviceIDs { + deviceKeys[i] = keyapi.DeviceKeys{ + UserID: userID, + DeviceID: did, + KeyJSON: nil, + } + } + + var uploadRes keyapi.PerformUploadKeysResponse + a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + DeviceKeys: deviceKeys, + }, &uploadRes) + if uploadRes.Error != nil { + return fmt.Errorf("Failed to delete device keys: %v", uploadRes.Error) + } + if len(uploadRes.KeyErrors) > 0 { + return fmt.Errorf("Failed to delete device keys, key errors: %+v", uploadRes.KeyErrors) + } + return nil +} + +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } return nil } diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 3e1ac0662..5f4df0eb1 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,8 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -91,6 +93,26 @@ func (h *httpUserInternalAPI) PerformDeviceCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformDeviceDeletion( + ctx context.Context, + request *api.PerformDeviceDeletionRequest, + response *api.PerformDeviceDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d29f4d442..47d68ff21 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeviceUpdatePath, + httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceUpdateRequest{} + response := api.PerformDeviceUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformDeviceDeletionPath, + httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceDeletionRequest{} + response := api.PerformDeviceDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryProfilePath, httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { request := api.QueryProfileRequest{} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index efe6f927c..9b535aab9 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) prep, err := s.db.Prepare(orig) if err != nil { return err @@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices( for i, v := range devices { params[i+1] = v } - params = append(params, params...) _, err = stmt.ExecContext(ctx, params...) return err }) diff --git a/userapi/userapi.go b/userapi/userapi.go index 7aadec06a..c4ab90bac 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,6 +17,7 @@ package userapi import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/config" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" @@ -34,12 +35,13 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, - serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI { + serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, ServerName: serverName, AppServices: appServices, + KeyAPI: keyAPI, } } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index b971964af..548148f27 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -37,7 +37,7 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, t.Fatalf("failed to create device DB: %s", err) } - return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB + return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB } func TestQueryProfile(t *testing.T) {