diff --git a/appservice/appservice.go b/appservice/appservice.go index b950a8219..5b1b93de2 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -38,7 +38,7 @@ import ( // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { client := &http.Client{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 9e9940cd3..de9f5aaf1 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -125,7 +125,7 @@ func TestAppserviceInternalAPI(t *testing.T) { // Create required internal APIs rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) runCases(t, asAPI) diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index f3fcb03e4..44e52286f 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -183,13 +182,11 @@ func startup() { rsAPI := roomserver.NewInternalAPI(base) federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, @@ -208,7 +205,6 @@ func startup() { FederationAPI: fedSenderAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), } diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index fad850a9e..32af611ae 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -165,9 +164,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -186,7 +183,6 @@ func (m *DendriteMonolith) Start() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index c7ca019ff..300d3a88a 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -8,7 +8,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -40,11 +39,9 @@ func TestAdminResetPassword(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for changing the password/login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -155,14 +152,12 @@ func TestPurgeRoom(t *testing.T) { fedClient := base.CreateFederationClient() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + syncapi.AddPublicRoutes(base, userAPI, rsAPI) federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) rsAPI.SetFederationAPI(nil, nil) - keyAPI.SetUserAPI(userAPI) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { @@ -170,7 +165,7 @@ func TestPurgeRoom(t *testing.T) { } // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 2d17e0928..e9985d43f 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,7 @@ package clientapi import ( + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" @@ -23,11 +24,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/routing" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. @@ -40,7 +39,6 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { cfg := &base.Cfg.ClientAPI @@ -61,7 +59,7 @@ func AddPublicRoutes( base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, + syncProducer, transactionsCache, fsAPI, extRoomsProvider, mscCfg, natsClient, ) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 4b4dedfd1..a01f6b944 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -16,14 +16,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -56,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -99,7 +98,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -130,7 +129,7 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De } } -func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { +func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -150,8 +149,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} - if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + accAvailableResp := &api.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{ Localpart: localpart, ServerName: serverName, }, accAvailableResp); err != nil { @@ -186,13 +185,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap return *internal.PasswordResponse(err) } - updateReq := &userapi.PerformPasswordUpdateRequest{ + updateReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, Password: request.Password, LogoutDevices: true, } - updateRes := &userapi.PerformPasswordUpdateResponse{} + updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -209,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } -func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse { _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") @@ -255,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien } } -func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 9e8208e6d..1450ef4bd 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -10,7 +10,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -29,8 +28,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 2570db09c..267ba1dc5 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -21,9 +21,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -34,8 +33,8 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, - keyserverAPI api.ClientKeyAPI, device *userapi.Device, - accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + keyserverAPI api.ClientKeyAPI, device *api.Device, + accountAPI api.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys( } } -func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { uploadReq := &api.PerformUploadDeviceSignaturesRequest{} uploadRes := &api.PerformUploadDeviceSignaturesResponse{} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 0c12b1117..3d60fcc3a 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -23,8 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) type uploadKeysRequest struct { @@ -32,7 +31,7 @@ type uploadKeysRequest struct { OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } -func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r uploadKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return timeout } -func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index d429d7f8c..b72db9d8b 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -39,12 +38,10 @@ func TestLogin(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for /login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil) // Create password password := util.RandomString(8) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 670c392bf..651e3d3d6 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -409,9 +408,7 @@ func Test_register(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -582,9 +579,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) { base.Cfg.Global.ServerName = "server" rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) deviceName, deviceID := "deviceName", "deviceID" expectedDisplayName := "DisplayName" response := completeRegistration( @@ -623,9 +618,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) expectedDisplayName := "rabbit" jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 66610c0a0..028d02e97 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -22,6 +22,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -37,11 +38,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -61,7 +60,6 @@ func Setup( syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { @@ -192,7 +190,7 @@ func Setup( dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminMarkAsStale(req, cfg, keyAPI) + return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -1372,11 +1370,11 @@ func Setup( // Cross-signing device keys postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) + return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) }) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceSignatures(req, keyAPI, device) + return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) @@ -1388,22 +1386,22 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req, keyAPI, device) + return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return ClaimKeys(req, keyAPI) + return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index fe19593ce..27720369e 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -38,7 +38,6 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" @@ -139,9 +138,7 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(p.BaseDendrite, &p.BaseDendrite.Cfg.KeyServer, fsAPI, rsComponent) - userAPI := userapi.NewInternalAPI(p.BaseDendrite, &cfg.UserAPI, nil, keyAPI, rsAPI, p.BaseDendrite.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation) asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI) @@ -175,7 +172,6 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 842682b4a..d759c6a73 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -39,7 +39,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -161,10 +160,7 @@ func main() { base, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) - - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -184,7 +180,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index 35bfc1d65..e8ff0a478 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" basepkg "github.com/matrix-org/dendrite/setup/base" @@ -56,10 +55,7 @@ func main() { keyRing := fsAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) - - pgClient := base.PushGatewayHTTPClient() - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) @@ -69,7 +65,6 @@ func main() { rsAPI.SetFederationAPI(fsAPI, keyRing) rsAPI.SetAppserviceAPI(asAPI) rsAPI.SetUserAPI(userAPI) - keyAPI.SetUserAPI(userAPI) monolith := setup.Monolith{ Config: base.Cfg, @@ -83,7 +78,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, } monolith.AddAllPublicRoutes(base) diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 601257d4b..7d9df3d78 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -26,11 +26,11 @@ import ( "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // KeyChangeConsumer consumes events that originate in key server. diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 108039167..ec482659a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -42,12 +41,11 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.FederationKeyAPI, servers federationAPI.ServersInRoomProvider, ) { cfg := &base.Cfg.FederationAPI @@ -79,7 +77,7 @@ func AddPublicRoutes( routing.Setup( base, rsAPI, f, keyRing, - federation, userAPI, keyAPI, mscCfg, + federation, userAPI, mscCfg, servers, producer, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 8d1d8514f..57d4b9644 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -17,13 +17,13 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type fedRoomserverAPI struct { @@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the // federationapi is incorrectly waiting for an output room event to arrive to update the joined // hosts table. - key := keyapi.DeviceMessage{ - Type: keyapi.TypeDeviceKeyUpdate, - DeviceKeys: &keyapi.DeviceKeys{ + key := userapi.DeviceMessage{ + Type: userapi.TypeDeviceKeyUpdate, + DeviceKeys: &userapi.DeviceKeys{ UserID: joiningUser.ID, DeviceID: "MY_DEVICE", DisplayName: "BLARGLE", @@ -277,7 +277,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil) baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 7cce13a7d..6bcfafa39 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -41,7 +41,7 @@ type SyncAPIProducer struct { TopicSigningKeyUpdate string JetStream nats.JetStreamContext Config *config.FederationAPI - UserAPI userapi.UserInternalAPI + UserAPI userapi.FederationUserAPI } func (p *SyncAPIProducer) SendReceipt( diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index ce8b06b70..871d26cd4 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -17,7 +17,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -26,11 +26,11 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - keyAPI keyapi.FederationKeyAPI, + keyAPI api.FederationKeyAPI, userID string, ) util.JSONResponse { - var res keyapi.QueryDeviceMessagesResponse - if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + var res api.QueryDeviceMessagesResponse + if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{ UserID: userID, }, &res); err != nil { return util.ErrorResponse(err) @@ -40,12 +40,12 @@ func GetUserDevices( return jsonerror.InternalServerError() } - sigReq := &keyapi.QuerySignaturesRequest{ + sigReq := &api.QuerySignaturesRequest{ TargetIDs: map[string][]gomatrixserverlib.KeyID{ userID: {}, }, } - sigRes := &keyapi.QuerySignaturesResponse{} + sigRes := &api.QuerySignaturesResponse{} for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index dc262cfde..2885cc916 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -22,8 +22,8 @@ import ( clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go index 763656081..3b9d576bf 100644 --- a/federationapi/routing/profile_test.go +++ b/federationapi/routing/profile_test.go @@ -62,7 +62,7 @@ func TestHandleQueryProfile(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go index 21f35bf0c..d839a16b8 100644 --- a/federationapi/routing/query_test.go +++ b/federationapi/routing/query_test.go @@ -62,7 +62,7 @@ func TestHandleQueryDirectory(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 5eb30c6ec..324740ddc 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -29,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -62,7 +61,6 @@ func Setup( keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, - keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, @@ -141,7 +139,7 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, + cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, ) }, )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) @@ -269,7 +267,7 @@ func Setup( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, keyAPI, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) @@ -494,14 +492,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 67b513c90..82651719f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -28,9 +28,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + userAPI "github.com/matrix-org/dendrite/userapi/api" ) const ( @@ -59,7 +59,7 @@ func Send( txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + keyAPI userAPI.FederationUserAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, mu *internal.MutexByRoom, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index d7feee0e5..eed4e7e69 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -58,7 +58,7 @@ func TestHandleSend(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 95673fc14..13b00af50 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -24,9 +24,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/types" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -56,7 +56,7 @@ var ( type TxnReq struct { gomatrixserverlib.Transaction rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI + userAPI userAPI.FederationUserAPI ourServerName gomatrixserverlib.ServerName keys gomatrixserverlib.JSONVerifier roomsMu *MutexByRoom @@ -66,7 +66,7 @@ type TxnReq struct { func NewTxnReq( rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + userAPI userAPI.FederationUserAPI, ourServerName gomatrixserverlib.ServerName, keys gomatrixserverlib.JSONVerifier, roomsMu *MutexByRoom, @@ -80,7 +80,7 @@ func NewTxnReq( ) TxnReq { t := TxnReq{ rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, ourServerName: ourServerName, keys: keys, roomsMu: roomsMu, diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 93c6fb6f1..8597ae24b 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -23,13 +23,13 @@ import ( "time" "github.com/matrix-org/dendrite/federationapi/producers" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + keyAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" diff --git a/keyserver/README.md b/keyserver/README.md deleted file mode 100644 index fd9f37d27..000000000 --- a/keyserver/README.md +++ /dev/null @@ -1,19 +0,0 @@ -## Key Server - -This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID. - -Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API. - -### Internal APIs -- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s). -- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls. -- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers. -- A topic which emits identity keys every time there is a change (addition or deletion). - -### Endpoint mappings -- Client API maps `/keys/upload` to `PerformUploadKeys`. -- Client API maps `/keys/query` to `QueryKeys`. -- Client API maps `/keys/claim` to `PerformClaimKeys`. -- Federation API maps `/user/keys/query` to `QueryKeys`. -- Federation API maps `/user/keys/claim` to `PerformClaimKeys`. -- Sync API maps `/keys/changes` to consuming from the Kafka topic. diff --git a/keyserver/api/api.go b/keyserver/api/api.go deleted file mode 100644 index 14fced3e8..000000000 --- a/keyserver/api/api.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "bytes" - "context" - "encoding/json" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/types" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI interface { - SyncKeyAPI - ClientKeyAPI - FederationKeyAPI - UserKeyAPI - - // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.KeyserverUserAPI) -} - -// API functions required by the clientapi -type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -// API functions required by the userapi -type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error -} - -// API functions required by the syncapi -type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error -} - -// KeyError is returned if there was a problem performing/querying the server -type KeyError struct { - Err string `json:"error"` - IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE - IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM - IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM -} - -func (k *KeyError) Error() string { - return k.Err -} - -type DeviceMessageType int - -const ( - TypeDeviceKeyUpdate DeviceMessageType = iota - TypeCrossSigningUpdate -) - -// DeviceMessage represents the message produced into Kafka by the key server. -type DeviceMessage struct { - Type DeviceMessageType `json:"Type,omitempty"` - *DeviceKeys `json:"DeviceKeys,omitempty"` - *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` - // A monotonically increasing number which represents device changes for this user. - StreamID int64 - DeviceChangeID int64 -} - -// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log -type OutputCrossSigningKeyUpdate struct { - CrossSigningKeyUpdate `json:"signing_keys"` -} - -type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` -} - -// DeviceKeysEqual returns true if the device keys updates contain the -// same display name and key JSON. This will return false if either of -// the updates is not a device keys update, or if the user ID/device ID -// differ between the two. -func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { - if m1.DeviceKeys == nil || m2.DeviceKeys == nil { - return false - } - if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { - return false - } - if m1.DisplayName != m2.DisplayName { - return false // different display names - } - if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { - return false // either is empty - } - return bytes.Equal(m1.KeyJSON, m2.KeyJSON) -} - -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { - return DeviceMessage{ - DeviceKeys: k, - StreamID: streamID, - } -} - -// OneTimeKeys represents a set of one-time keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type OneTimeKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // A map of algorithm:key_id => key JSON - KeyJSON map[string]json.RawMessage -} - -// Split a key in KeyJSON into algorithm and key ID -func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { - segments := strings.Split(keyIDWithAlgo, ":") - return segments[0], segments[1] -} - -// OneTimeKeysCount represents the counts of one-time keys for a single device -type OneTimeKeysCount struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // algorithm to count e.g: - // { - // "curve25519": 10, - // "signed_curve25519": 20 - // } - KeyCount map[string]int -} - -// PerformUploadKeysRequest is the request to PerformUploadKeys -type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys - // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update - // the display name for their respective device, and NOT to modify the keys. The key - // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. - // Without this flag, requests to modify device display names would delete device keys. - OnlyDisplayNameUpdates bool -} - -// PerformUploadKeysResponse is the response to PerformUploadKeys -type PerformUploadKeysResponse struct { - // A fatal error when processing e.g database failures - Error *KeyError - // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount -} - -// PerformDeleteKeysRequest asks the keyserver to forget about certain -// keys, and signatures related to those keys. -type PerformDeleteKeysRequest struct { - UserID string - KeyIDs []gomatrixserverlib.KeyID -} - -// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. -type PerformDeleteKeysResponse struct { - Error *KeyError -} - -// KeyError sets a key error field on KeyErrors -func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { - if r.KeyErrors[userID] == nil { - r.KeyErrors[userID] = make(map[string]*KeyError) - } - r.KeyErrors[userID][deviceID] = err -} - -type PerformClaimKeysRequest struct { - // Map of user_id to device_id to algorithm name - OneTimeKeys map[string]map[string]string - Timeout time.Duration -} - -type PerformClaimKeysResponse struct { - // Map of user_id to device_id to algorithm:key_id to key JSON - OneTimeKeys map[string]map[string]map[string]json.RawMessage - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Set if there was a fatal error processing this action - Error *KeyError -} - -type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys - // The user that uploaded the key, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceKeysResponse struct { - Error *KeyError -} - -type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice - // The user that uploaded the sig, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceSignaturesResponse struct { - Error *KeyError -} - -type QueryKeysRequest struct { - // The user ID asking for the keys, e.g. if from a client API request. - // Will not be populated if the key request came from federation. - UserID string - // Maps user IDs to a list of devices - UserToDevices map[string][]string - Timeout time.Duration -} - -type QueryKeysResponse struct { - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage - // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // Set if there was a fatal error processing this query - Error *KeyError -} - -type QueryKeyChangesRequest struct { - // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning - Offset int64 - // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). - ToOffset int64 -} - -type QueryKeyChangesResponse struct { - // The set of users who have had their keys change. - UserIDs []string - // The latest offset represented in this response. - Offset int64 - // Set if there was a problem handling the request. - Error *KeyError -} - -type QueryOneTimeKeysRequest struct { - // The local user to query OTK counts for - UserID string - // The device to query OTK counts for - DeviceID string -} - -type QueryOneTimeKeysResponse struct { - // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError -} - -type QueryDeviceMessagesRequest struct { - UserID string -} - -type QueryDeviceMessagesResponse struct { - // The latest stream ID - StreamID int64 - Devices []DeviceMessage - Error *KeyError -} - -type QuerySignaturesRequest struct { - // A map of target user ID -> target key/device IDs to retrieve signatures for - TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` -} - -type QuerySignaturesResponse struct { - // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures - Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // The request error, if any - Error *KeyError -} - -type PerformMarkAsStaleRequest struct { - UserID string - Domain gomatrixserverlib.ServerName - DeviceID string -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go deleted file mode 100644 index 2d143682d..000000000 --- a/keyserver/keyserver.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keyserver - -import ( - "github.com/sirupsen/logrus" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" -) - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, - rsAPI rsapi.KeyserverRoomserverAPI, -) api.KeyInternalAPI { - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - db, err := storage.NewDatabase(base, &cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to key server database") - } - - keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), - JetStream: js, - DB: db, - } - ap := &internal.KeyInternalAPI{ - DB: db, - Cfg: cfg, - FedClient: fedClient, - Producer: keyChangeProducer, - } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable - ap.Updater = updater - - // Remove users which we don't share a room with anymore - if err := updater.CleanUp(); err != nil { - logrus.WithError(err).Error("failed to cleanup stale device lists") - } - - go func() { - if err := updater.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } - }() - - dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, - ) - if err := dlConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start device list consumer") - } - - sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, ap, - ) - if err := sigConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start signing key consumer") - } - - return ap -} diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go deleted file mode 100644 index 159b280f5..000000000 --- a/keyserver/keyserver_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package keyserver - -import ( - "context" - "testing" - - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -// Merely tests that we can create an internal keyserver API -func Test_NewInternalAPI(t *testing.T) { - rsAPI := &mockKeyserverRoomserverAPI{} - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - defer closeBase() - _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - }) -} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go deleted file mode 100644 index c6a8f44cd..000000000 --- a/keyserver/storage/interface.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination - // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. - ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - - // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - - // OneTimeKeysCount returns a count of all OTKs for this device. - OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). - // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. - // Returns an error if there was a problem storing the keys. - StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. - // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - - // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying - // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error - - // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key - // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. - ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of types.OffsetNewest means no upper limit. - // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) - CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - - DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, - ) error -} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go deleted file mode 100644 index 35e630559..000000000 --- a/keyserver/storage/postgres/storage.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) - if err != nil { - return nil, err - } - otk, err := NewPostgresOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewPostgresDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewPostgresKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewPostgresStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewPostgresCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewPostgresCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go deleted file mode 100644 index 54dd6ddc9..000000000 --- a/keyserver/storage/shared/storage.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges - StaleDeviceListsTable tables.StaleDeviceLists - CrossSigningKeysTable tables.CrossSigningKeys - CrossSigningSigsTable tables.CrossSigningSigs -} - -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) -} - -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return err - }) - return -} - -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) -} - -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) -} - -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) - if err != nil { - return false, err - } - return count == len(prevIDs), nil -} - -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, userID := range clearUserIDs { - err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) - if err != nil { - return err - } - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { - // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int64) - for _, k := range keys { - userIDToStreamID[k.UserID] = 0 - } - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID := range userIDToStreamID { - streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) - if err != nil { - return err - } - userIDToStreamID[userID] = streamID - } - // set the stream IDs for each key - for i := range keys { - k := keys[i] - userIDToStreamID[k.UserID]++ // start stream from 1 - k.StreamID = userIDToStreamID[k.UserID] - keys[i] = k - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) -} - -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { - var result []api.OneTimeKeys - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID, deviceToAlgo := range userToDeviceToAlgorithm { - for deviceID, algo := range deviceToAlgo { - keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) - if err != nil { - return err - } - if keyJSON != nil { - result = append(result, api.OneTimeKeys{ - UserID: userID, - DeviceID: deviceID, - KeyJSON: keyJSON, - }) - } - } - } - return nil - }) - return result, err -} - -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { - err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) - return err - }) - return -} - -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) - }) -} - -// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying -// cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, deviceID := range deviceIDs { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) - } - if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) - } - } - return nil - }) -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { - keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) - if err != nil { - return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) - } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: key, - }, - } - sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) - if err != nil { - continue - } - for sigUserID, forSigUserID := range sigMap { - if userID != sigUserID { - continue - } - if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sigKeyID, sigBytes := range forSigUserID { - result.Signatures[sigUserID][sigKeyID] = sigBytes - } - } - results[purpose] = result - } - return results, nil -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) -} - -// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { - return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) -} - -// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { - return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) - } - } - return nil - }) -} - -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( - ctx context.Context, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) - } - return nil - }) -} - -// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. -func (d *Database) DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) - }) -} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go deleted file mode 100644 index 873fe3e24..000000000 --- a/keyserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) - if err != nil { - return nil, err - } - otk, err := NewSqliteOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewSqliteDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewSqliteCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewSqliteCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go deleted file mode 100644 index ab6a35401..000000000 --- a/keyserver/storage/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go deleted file mode 100644 index e7a2af7c2..000000000 --- a/keyserver/storage/storage_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package storage_test - -import ( - "context" - "reflect" - "sync" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ctx = context.Background() - -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) - if err != nil { - t.Fatalf("failed to create new database: %v", err) - } - return db, close -} - -func MustNotError(t *testing.T, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("operation failed: %s", err) -} - -func TestKeyChanges(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesNoDupes(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesUpperLimit(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -var dbLock sync.Mutex -var deviceArray = []string{"AAA", "another_device"} - -// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, -// and that they are returned correctly when querying for device keys. -func TestDeviceKeysStreamIDGeneration(t *testing.T) { - var err error - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - dbLock.Lock() - defer dbLock.Unlock() - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) - - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) - } - } - }) -} diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go deleted file mode 100644 index 75c9053e8..000000000 --- a/keyserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go deleted file mode 100644 index 24da1125e..000000000 --- a/keyserver/storage/tables/interface.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type OneTimeKeys interface { - SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. - // Returns an empty map if the key does not exist. - SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error -} - -type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) - CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error - DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error -} - -type KeyChanges interface { - InsertKeyChange(ctx context.Context, userID string) (int64, error) - // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - Prepare() error -} - -type StaleDeviceLists interface { - InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error -} - -type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error -} - -type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error -} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a8228ae81..73732ae32 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,7 +17,6 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI - KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -167,6 +166,7 @@ type ClientRoomserverAPI interface { type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI + KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 7ba01e50f..304311c4c 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -12,7 +12,6 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" @@ -47,7 +46,7 @@ func TestUsers(t *testing.T) { }) t.Run("kick users", func(t *testing.T) { - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) rsAPI.SetUserAPI(usrAPI) testKickUsers(t, rsAPI, usrAPI) }) @@ -228,11 +227,10 @@ func TestPurgeRoom(t *testing.T) { fedClient := base.CreateFederationClient() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + syncapi.AddPublicRoutes(base, userAPI, rsAPI) federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) rsAPI.SetFederationAPI(nil, nil) diff --git a/setup/monolith.go b/setup/monolith.go index 5bbe4019e..d8c652234 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" @@ -45,7 +44,6 @@ type Monolith struct { FederationAPI federationAPI.FederationInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI - KeyAPI keyAPI.KeyInternalAPI RelayAPI relayAPI.RelayInternalAPI // Optional @@ -61,19 +59,14 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { } clientapi.AddPublicRoutes( base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, + m.FederationAPI, m.UserAPI, userDirectoryProvider, m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( - base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, - m.KeyAPI, nil, - ) - mediaapi.AddPublicRoutes( - base, m.UserAPI, m.Client, - ) - syncapi.AddPublicRoutes( - base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, + base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, ) + mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client) + syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI) if m.RelayAPI != nil { relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 92f081500..5faaefb8e 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -19,7 +19,6 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -28,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 356e83263..32208c585 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -25,7 +25,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -33,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. @@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct { durable string topic string db storage.Database - keyAPI keyapi.SyncKeyAPI + userAPI api.SyncKeyAPI isLocalServerName func(gomatrixserverlib.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier @@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, js nats.JetStreamContext, store storage.Database, - keyAPI keyapi.SyncKeyAPI, + userAPI api.SyncKeyAPI, notifier *notifier.Notifier, stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { @@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer( topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), db: store, - keyAPI: keyAPI, + userAPI: userAPI, isLocalServerName: cfg.Matrix.IsLocalServerName, notifier: notifier, stream: stream, @@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. - if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ + if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, }, &struct{}{}); err != nil { logger.WithError(err).Errorf("failed to mark as stale if needed") diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 4867f7d9e..e7f677c85 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,22 +18,22 @@ import ( "context" "strings" + keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // DeviceOTKCounts adds one-time key counts to the /sync response -func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { - var queryRes keyapi.QueryOneTimeKeysResponse - _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ +func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + _ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. func DeviceListCatchup( - ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, + ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, res *types.Response, from, to types.StreamPosition, ) (newPos types.StreamPosition, hasNew bool, err error) { @@ -74,8 +74,8 @@ func DeviceListCatchup( if from > 0 { offset = int64(from) } - var queryRes keyapi.QueryKeyChangesResponse - _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + var queryRes api.QueryKeyChangesResponse + _ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index d64bea11e..4bb851668 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -22,44 +21,44 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error { +func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error { return nil } -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error { return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { return nil } diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 7996c2038..e8189c352 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -3,17 +3,17 @@ package streams import ( "context" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type DeviceListStreamProvider struct { DefaultStreamProvider - rsAPI api.SyncRoomserverAPI - keyAPI keyapi.SyncKeyAPI + rsAPI api.SyncRoomserverAPI + userAPI userapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( @@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from } - err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response) + err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index dc8547621..a35491acf 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -5,7 +5,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" @@ -27,7 +26,7 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.SyncUserAPI, - rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI, + rsAPI rsapi.SyncRoomserverAPI, eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ @@ -60,7 +59,7 @@ func NewSyncStreamProviders( DeviceListStreamProvider: &DeviceListStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b086567b8..68f913871 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -48,7 +47,6 @@ type RequestPool struct { db storage.Database cfg *config.SyncAPI userAPI userapi.SyncUserAPI - keyAPI keyapi.SyncKeyAPI rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map presence *sync.Map @@ -69,7 +67,7 @@ type PresenceConsumer interface { // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, - userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, + userAPI userapi.SyncUserAPI, rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool, @@ -83,7 +81,6 @@ func NewRequestPool( db: db, cfg: cfg, userAPI: userAPI, - keyAPI: keyAPI, rsAPI: rsAPI, lastseen: &sync.Map{}, presence: &sync.Map{}, @@ -280,7 +277,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 // Only try to get OTKs if the context isn't already done. if syncReq.Context.Err() == nil { - err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) + err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } @@ -551,7 +548,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index be19310f2..153f7af5d 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -42,7 +41,6 @@ func AddPublicRoutes( base *base.BaseDendrite, userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, - keyAPI keyapi.SyncKeyAPI, ) { cfg := &base.Cfg.SyncAPI @@ -55,7 +53,7 @@ func AddPublicRoutes( eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") @@ -71,7 +69,7 @@ func AddPublicRoutes( userAPI, ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start presence consumer") @@ -117,7 +115,7 @@ func AddPublicRoutes( } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 643c30265..e748660f7 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -16,7 +16,6 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/producers" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -83,21 +82,18 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc return nil } +func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { + return nil +} + +func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { + return nil +} + func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { return nil } -type syncKeyAPI struct { - keyapi.SyncKeyAPI -} - -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { - return nil -} -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { - return nil -} - func TestSyncAPIAccessTokens(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { testSyncAccessTokens(t, dbType) @@ -121,7 +117,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -220,7 +216,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -304,7 +300,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -422,7 +418,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -722,7 +718,7 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -789,7 +785,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -1008,7 +1004,7 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI) room := test.NewRoom(t, user) diff --git a/userapi/api/api.go b/userapi/api/api.go index 4ea2e91c3..fa297f773 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -15,9 +15,13 @@ package api import ( + "bytes" "context" "encoding/json" + "strings" + "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -26,15 +30,12 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - AppserviceUserAPI SyncUserAPI ClientUserAPI - MediaUserAPI FederationUserAPI - RoomserverUserAPI - KeyserverUserAPI QuerySearchProfilesAPI // used by p2p demos + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the appservice api @@ -43,11 +44,6 @@ type AppserviceUserAPI interface { PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error } -type KeyserverUserAPI interface { - QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error -} - type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) @@ -60,13 +56,20 @@ type MediaUserAPI interface { // api functions required by the federation api type FederationUserAPI interface { + UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -79,6 +82,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + ClientKeyAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } + +// API functions required by the clientapi +type ClientKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type UploadDeviceKeysAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error +} + +// API functions required by the syncapi +type SyncKeyAPI interface { + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type FederationKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error +} + +// KeyError is returned if there was a problem performing/querying the server +type KeyError struct { + Err string `json:"error"` + IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE + IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM + IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM +} + +func (k *KeyError) Error() string { + return k.Err +} + +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` + *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` + // A monotonically increasing number which represents device changes for this user. + StreamID int64 + DeviceChangeID int64 +} + +// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log +type OutputCrossSigningKeyUpdate struct { + CrossSigningKeyUpdate `json:"signing_keys"` +} + +type CrossSigningKeyUpdate struct { + MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` +} + +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + +// DeviceKeys represents a set of device keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type DeviceKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // The device display name + DisplayName string + // The raw device key JSON + KeyJSON []byte +} + +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { + return DeviceMessage{ + DeviceKeys: k, + StreamID: streamID, + } +} + +// OneTimeKeys represents a set of one-time keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type OneTimeKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + +// OneTimeKeysCount represents the counts of one-time keys for a single device +type OneTimeKeysCount struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // algorithm to count e.g: + // { + // "curve25519": 10, + // "signed_curve25519": 20 + // } + KeyCount map[string]int +} + +// PerformUploadKeysRequest is the request to PerformUploadKeys +type PerformUploadKeysRequest struct { + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool +} + +// PerformUploadKeysResponse is the response to PerformUploadKeys +type PerformUploadKeysResponse struct { + // A fatal error when processing e.g database failures + Error *KeyError + // A map of user_id -> device_id -> Error for tracking failures. + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount +} + +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + +type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration +} + +type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action + Error *KeyError +} + +type PerformUploadDeviceKeysRequest struct { + gomatrixserverlib.CrossSigningKeys + // The user that uploaded the key, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceKeysResponse struct { + Error *KeyError +} + +type PerformUploadDeviceSignaturesRequest struct { + Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + // The user that uploaded the sig, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceSignaturesResponse struct { + Error *KeyError +} + +type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration +} + +type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + +type QueryKeyChangesRequest struct { + // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 + // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + ToOffset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int64 + Devices []DeviceMessage + Error *KeyError +} + +type QuerySignaturesRequest struct { + // A map of target user ID -> target key/device IDs to retrieve signatures for + TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` +} + +type QuerySignaturesResponse struct { + // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures + Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing user-signing key + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError +} + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e77..51bd2753a 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/keyserver/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go similarity index 97% rename from keyserver/consumers/devicelistupdate.go rename to userapi/consumers/devicelistupdate.go index cd911f8c6..a65889fcc 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -18,11 +18,11 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -41,7 +41,7 @@ type DeviceListUpdateConsumer struct { // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. func NewDeviceListUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3ce5af621..47d330959 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 39f4aab4a..bc5ae652d 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,11 +18,11 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { diff --git a/keyserver/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go similarity index 87% rename from keyserver/consumers/signingkeyupdate.go rename to userapi/consumers/signingkeyupdate.go index bcceaad15..f4ff017db 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -22,11 +22,10 @@ import ( "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. @@ -35,24 +34,24 @@ type SigningKeyUpdateConsumer struct { jetstream nats.JetStreamContext durable string topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + userAPI api.UploadDeviceKeysAPI + cfg *config.UserAPI isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. func NewSigningKeyUpdateConsumer( process *process.ProcessContext, - cfg *config.KeyServer, + cfg *config.UserAPI, js nats.JetStreamContext, - keyAPI *internal.KeyInternalAPI, + userAPI api.UploadDeviceKeysAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ ctx: process.Context(), jetstream: js, durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, + userAPI: userAPI, cfg: cfg, isLocalServerName: cfg.Matrix.IsLocalServerName, } @@ -70,7 +69,7 @@ func (t *SigningKeyUpdateConsumer) Start() error { // signing key update events topic from the key server. func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called - var updatePayload keyapi.CrossSigningKeyUpdate + var updatePayload api.CrossSigningKeyUpdate if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { logrus.WithError(err).Errorf("Failed to read from signing key update input topic") return true @@ -94,12 +93,12 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if updatePayload.SelfSigningKey != nil { keys.SelfSigningKey = *updatePayload.SelfSigningKey } - uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ + uploadReq := &api.PerformUploadDeviceKeysRequest{ CrossSigningKeys: keys, UserID: updatePayload.UserID, } - uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { logrus.WithError(err).Error("failed to upload device keys") return false } diff --git a/keyserver/internal/cross_signing.go b/userapi/internal/cross_signing.go similarity index 91% rename from keyserver/internal/cross_signing.go rename to userapi/internal/cross_signing.go index 99859dff6..8b9704d1b 100644 --- a/keyserver/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -22,8 +22,8 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" @@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } // nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -169,7 +169,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P // something if any of the specified keys in the request are different // to what we've got in the database, to avoid generating key change // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), @@ -216,7 +216,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } @@ -234,7 +234,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P continue } for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } @@ -257,7 +257,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P if update.MasterKey == nil && update.SelfSigningKey == nil { return nil } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -266,7 +266,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P return nil } -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -342,7 +342,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req MasterKey: &masterKey, SelfSigningKey: &selfSigningKey, } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } @@ -352,7 +352,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req return nil } -func (a *KeyInternalAPI) processSelfSignatures( +func (a *UserInternalAPI) processSelfSignatures( ctx context.Context, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -373,7 +373,7 @@ func (a *KeyInternalAPI) processSelfSignatures( } for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -384,7 +384,7 @@ func (a *KeyInternalAPI) processSelfSignatures( case *gomatrixserverlib.DeviceKeys: for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -401,7 +401,7 @@ func (a *KeyInternalAPI) processSelfSignatures( return nil } -func (a *KeyInternalAPI) processOtherSignatures( +func (a *UserInternalAPI) processOtherSignatures( ctx context.Context, userID string, queryRes *api.QueryKeysResponse, signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, ) error { @@ -442,7 +442,7 @@ func (a *KeyInternalAPI) processOtherSignatures( } for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, ); err != nil { return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) @@ -461,11 +461,11 @@ func (a *KeyInternalAPI) processOtherSignatures( return nil } -func (a *KeyInternalAPI) crossSigningKeysFromDatabase( +func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue @@ -478,7 +478,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( break } - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue @@ -522,9 +522,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( } } -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), @@ -556,7 +556,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign for _, targetKeyID := range forTargetUser { // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) if err != nil && err != sql.ErrNoRows { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), diff --git a/keyserver/internal/device_list_update.go b/userapi/internal/device_list_update.go similarity index 99% rename from keyserver/internal/device_list_update.go rename to userapi/internal/device_list_update.go index 1b00d1eec..3b4dcf98e 100644 --- a/keyserver/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -33,8 +33,8 @@ import ( "github.com/sirupsen/logrus" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) var ( diff --git a/keyserver/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go similarity index 100% rename from keyserver/internal/device_list_update_default.go rename to userapi/internal/device_list_update_default.go diff --git a/keyserver/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go similarity index 100% rename from keyserver/internal/device_list_update_sytest.go rename to userapi/internal/device_list_update_sytest.go diff --git a/keyserver/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go similarity index 98% rename from keyserver/internal/device_list_update_test.go rename to userapi/internal/device_list_update_test.go index 60a2c2f30..868fc9be8 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -29,13 +29,13 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -360,12 +360,12 @@ func TestDebounce(t *testing.T) { } } -func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() base, _, _ := testrig.Base(nil) connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) } diff --git a/keyserver/internal/internal.go b/userapi/internal/key_api.go similarity index 83% rename from keyserver/internal/internal.go rename to userapi/internal/key_api.go index 9a08a0bb7..be816fe5d 100644 --- a/keyserver/internal/internal.go +++ b/userapi/internal/key_api.go @@ -29,29 +29,11 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -type KeyInternalAPI struct { - DB storage.Database - Cfg *config.KeyServer - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater -} - -func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - a.UserAPI = i -} - -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) +func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), @@ -63,7 +45,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC return nil } -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { +func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { res.KeyErrors = make(map[string]map[string]*api.KeyError) if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) @@ -71,7 +53,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } @@ -79,7 +61,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -97,11 +79,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC domainToDeviceKeys[string(serverName)] = nested } for domain, local := range domainToDeviceKeys { - if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), @@ -129,7 +111,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC return nil } -func (a *KeyInternalAPI) claimRemoteKeys( +func (a *UserInternalAPI) claimRemoteKeys( ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ) { var wg sync.WaitGroup // Wait for fan-out goroutines to finish @@ -146,7 +128,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -177,8 +159,8 @@ func (a *KeyInternalAPI) claimRemoteKeys( }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) } -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { +func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } @@ -186,8 +168,8 @@ func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perform return nil } -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) +func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), @@ -198,8 +180,8 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne return nil } -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) +func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -225,8 +207,8 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query // PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present // in our database. -func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) +func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) if err != nil { return err } @@ -244,7 +226,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap } // nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -262,8 +244,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if a.Cfg.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + if a.Config.Matrix.IsLocalServerName(serverName) { + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -276,8 +258,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for _, dk := range deviceKeys { dids = append(dids, dk.DeviceID) } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + var queryRes api.QueryDeviceInfosResponse + err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{ DeviceIDs: dids, }, &queryRes) if err != nil { @@ -341,14 +323,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} } for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -367,14 +349,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques for targetUserID, forUserID := range res.DeviceKeys { for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) if err != nil { // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue } if len(sigMap) == 0 { @@ -403,7 +385,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques return nil } -func (a *KeyInternalAPI) remoteKeysFromDatabase( +func (a *UserInternalAPI) remoteKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, ) map[string]map[string][]string { fetchRemote := make(map[string]map[string][]string) @@ -429,7 +411,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( return fetchRemote } -func (a *KeyInternalAPI) queryRemoteKeys( +func (a *UserInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, ) { @@ -441,13 +423,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -499,7 +481,7 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) queryRemoteKeysOnServer( +func (a *UserInternalAPI) queryRemoteKeysOnServer( ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, res *api.QueryKeysResponse, @@ -559,7 +541,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -586,10 +568,10 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( respMu.Unlock() } -func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( +func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -621,11 +603,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( return nil } -func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { // get a list of devices from the user API that actually exist, as // we won't store keys for devices that don't exist - uapidevices := &userapi.QueryDevicesResponse{} - if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + uapidevices := &api.QueryDevicesResponse{} + if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { res.Error = &api.KeyError{ Err: err.Error(), } @@ -643,7 +625,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), @@ -662,7 +644,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) } else { logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) @@ -693,7 +675,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if !a.Cfg.Matrix.IsLocalServerName(serverName) { + if !a.Config.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { @@ -722,30 +704,30 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) + err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } } -func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { if req.UserID == "" { res.Error = &api.KeyError{ Err: "user ID missing", } } if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), } } if counts != nil { @@ -761,7 +743,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: "failed to query existing one-time keys: " + err.Error(), @@ -778,7 +760,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } } // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), diff --git a/keyserver/internal/internal_test.go b/userapi/internal/key_api_test.go similarity index 89% rename from keyserver/internal/internal_test.go rename to userapi/internal/key_api_test.go index 8a2c9c5d9..fc7e7e0df 100644 --- a/keyserver/internal/internal_test.go +++ b/userapi/internal/key_api_test.go @@ -5,23 +5,28 @@ import ( "reflect" "testing" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/storage" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(nil, &config.DatabaseOptions{ + base, _, _ := testrig.Base(nil) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } - return db, close + return db, func() { + base.Close() + close() + } } func Test_QueryDeviceMessages(t *testing.T) { @@ -140,8 +145,8 @@ func Test_QueryDeviceMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &internal.KeyInternalAPI{ - DB: db, + a := &internal.UserInternalAPI{ + KeyDatabase: db, } if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go similarity index 97% rename from userapi/internal/api.go rename to userapi/internal/user_api.go index 0bb480da6..1cbd97190 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/user_api.go @@ -23,6 +23,7 @@ import ( "strconv" "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -32,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" synctypes "github.com/matrix-org/dendrite/syncapi/types" @@ -44,17 +44,19 @@ import ( ) type UserInternalAPI struct { - DB storage.Database - SyncProducer *producers.SyncAPI - Config *config.UserAPI + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase + SyncProducer *producers.SyncAPI + KeyChangeProducer *producers.KeyChange + Config *config.UserAPI DisableTLSValidation bool // AppServices is the list of all registered AS AppServices []config.ApplicationService - KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client - Cfg *config.UserAPI + FedClient fedsenderapi.KeyserverFederationAPI + Updater *DeviceListUpdater } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return fmt.Errorf("a.DB.SetDisplayName: %w", err) } - postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true res.Account = acc @@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // Ask the keyserver to delete device keys and signatures for those devices - deleteReq := &keyapi.PerformDeleteKeysRequest{ + deleteReq := &api.PerformDeleteKeysRequest{ UserID: req.UserID, } for _, keyID := range req.DeviceIDs { deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } - deleteRes := &keyapi.PerformDeleteKeysResponse{} - if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + deleteRes := &api.PerformDeleteKeysResponse{} + if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { return err } if err := deleteRes.Error; err != nil { @@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ + deviceKeys[i] = api.DeviceKeys{ UserID: userID, DeviceID: did, KeyJSON: nil, } } - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, }, &uploadRes); err != nil { @@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf } if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ + DeviceKeys: []api.DeviceKeys{ { DeviceID: dev.ID, DisplayName: *req.DisplayName, diff --git a/keyserver/producers/keychange.go b/userapi/producers/keychange.go similarity index 94% rename from keyserver/producers/keychange.go rename to userapi/producers/keychange.go index f86c34177..da6cea31a 100644 --- a/keyserver/producers/keychange.go +++ b/userapi/producers/keychange.go @@ -18,9 +18,9 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -28,8 +28,8 @@ import ( // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { Topic string - JetStream nats.JetStreamContext - DB storage.Database + JetStream JetStreamPublisher + DB storage.KeyChangeDatabase } // ProduceKeyChanges creates new change events for each key diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa9856..165de8994 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,13 +19,13 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658f..278378861 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d74..057160374 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_keys_table.go rename to userapi/storage/postgres/cross_signing_keys_table.go index 1022157e8..c0ecbd303 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/postgres/cross_signing_sigs_table.go rename to userapi/storage/postgres/cross_signing_sigs_table.go index 4536b7d80..b0117145c 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022012016470000_key_changes.go rename to userapi/storage/postgres/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go similarity index 87% rename from keyserver/storage/postgres/device_keys_table.go rename to userapi/storage/postgres/device_keys_table.go index 2aa11c520..a9203857f 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/userapi/storage/postgres/device_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -92,31 +92,16 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7481ac5bb..88f8839c5 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7bae..91a34c357 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go similarity index 90% rename from keyserver/storage/postgres/key_changes_table.go rename to userapi/storage/postgres/key_changes_table.go index c0e3429c7..a00494140 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/userapi/storage/postgres/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -66,7 +66,10 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { if err = executeMigration(context.Background(), db); err != nil { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -95,16 +98,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go similarity index 89% rename from keyserver/storage/postgres/one_time_keys_table.go rename to userapi/storage/postgres/one_time_keys_table.go index 2117efcae..972a59147 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -49,7 +49,7 @@ const upsertKeysSQL = "" + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" const selectKeysCountSQL = "" + @@ -84,25 +84,14 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go similarity index 98% rename from keyserver/storage/postgres/stale_device_lists.go rename to userapi/storage/postgres/stale_device_lists.go index 248ddfb45..c823b58c6 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/userapi/storage/postgres/stale_device_lists.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 92dc48081..673d123ba 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index bf94f14d6..d3272a032 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) +} + +// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for keyType, keyData := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) + } + } + return nil + }) +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( + ctx context.Context, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) + } + return nil + }) +} + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *KeyDatabase) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_keys_table.go rename to userapi/storage/sqlite3/cross_signing_keys_table.go index e103d9883..10721fcc8 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go similarity index 96% rename from keyserver/storage/sqlite3/cross_signing_sigs_table.go rename to userapi/storage/sqlite3/cross_signing_sigs_table.go index 7a153e8fb..2be00c9c1 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -21,9 +21,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go rename to userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go similarity index 100% rename from keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go rename to userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go similarity index 88% rename from keyserver/storage/sqlite3/device_keys_table.go rename to userapi/storage/sqlite3/device_keys_table.go index 73768da5b..15e69cc4c 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/userapi/storage/sqlite3/device_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var deviceKeysSchema = ` @@ -86,28 +86,16 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if err != nil { return nil, err } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) } func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 449e4549a..65e17527d 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, @@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb19..ed2746310 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go similarity index 90% rename from keyserver/storage/sqlite3/key_changes_table.go rename to userapi/storage/sqlite3/key_changes_table.go index 0c844d67a..923bb57eb 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var keyChangesSchema = ` @@ -65,7 +65,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { return nil, err } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) } func executeMigration(ctx context.Context, db *sql.DB) error { @@ -93,16 +96,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { return m.Up(ctx) } -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) return diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go similarity index 89% rename from keyserver/storage/sqlite3/one_time_keys_table.go rename to userapi/storage/sqlite3/one_time_keys_table.go index 7a923d0e5..a992d399c 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) var oneTimeKeysSchema = ` @@ -48,7 +48,7 @@ const upsertKeysSQL = "" + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + " DO UPDATE SET key_json = $6" -const selectKeysSQL = "" + +const selectOneTimeKeysSQL = "" + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" const selectKeysCountSQL = "" + @@ -83,25 +83,14 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if err != nil { return nil, err } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go similarity index 98% rename from keyserver/storage/sqlite3/stale_device_lists.go rename to userapi/storage/sqlite3/stale_device_lists.go index fd76a6e3b..f078fc99f 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c944..72b3ba49d 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f7063..0f3eeed1b 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewSqliteCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 42221e752..0329fb46a 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -29,15 +29,36 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + bcryptCost int, + openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, + serverNoticesLocalpart string, +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23aafff03..f52e7e17d 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e6..163e3e173 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 9221e571f..693e73038 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -145,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go similarity index 94% rename from keyserver/storage/tables/stale_device_lists_test.go rename to userapi/storage/tables/stale_device_lists_test.go index 76d3baddd..b9bdafdaa 100644 --- a/keyserver/storage/tables/stale_device_lists_test.go +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -4,15 +4,15 @@ import ( "context" "testing" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { diff --git a/keyserver/types/storage.go b/userapi/types/storage.go similarity index 100% rename from keyserver/types/storage.go rename to userapi/types/storage.go diff --git a/userapi/userapi.go b/userapi/userapi.go index 2dd81d759..826bd7213 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,13 +17,11 @@ package userapi import ( "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal/pushgateway" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" @@ -33,16 +31,20 @@ import ( "github.com/matrix-org/dendrite/userapi/util" ) -// NewInternalAPI returns a concerete implementation of the internal API. Callers +// NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, - rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, -) api.UserInternalAPI { + base *base.BaseDendrite, + rsAPI rsapi.UserRoomserverAPI, + fedClient fedsenderapi.KeyserverFederationAPI, +) *internal.UserInternalAPI { + cfg := &base.Cfg.UserAPI js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + appServices := base.Cfg.Derived.ApplicationServices - db, err := storage.NewUserAPIDatabase( + pgClient := base.PushGatewayHTTPClient() + + db, err := storage.NewUserDatabase( base, &cfg.AccountDatabase, cfg.Matrix.ServerName, @@ -55,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -64,17 +71,50 @@ func NewInternalAPI( cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), ) + keyChangeProducer := &producers.KeyChange{ + Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + JetStream: js, + DB: keyDB, + } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, Config: cfg, AppServices: appServices, - KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, - Cfg: cfg, + FedClient: fedClient, + } + + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + userAPI.Updater = updater + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() + + dlConsumer := consumers.NewDeviceListUpdateConsumer( + base.ProcessContext, cfg, js, updater, + ) + if err := dlConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start device list consumer") + } + + sigConsumer := consumers.NewSigningKeyUpdateConsumer( + base.ProcessContext, cfg, js, userAPI, + ) + if err := sigConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 68d08c2fd..08b1336be 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -21,7 +21,10 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/setup/config" @@ -38,32 +41,55 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +type dummyProducer struct{} + +func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) { + return &nats.PubAck{}, nil +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } + accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + KeyDatabase: keyDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) { testCases(t, intAPI) }) } + +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +} + +func TestDevices(t *testing.T) { + ctx := context.Background() + + dupeAccessToken := util.RandomString(8) + + displayName := "testing" + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + + newDisplayName := "new name" + if tc.inputData.DeviceDisplayName == nil { + updateRes := api.PerformDeviceUpdateResponse{} + updateReq := api.PerformDeviceUpdateRequest{ + RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"), + DeviceID: deviceID, + DisplayName: &newDisplayName, + } + + if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil { + t.Fatal(err) + } + } + + queryDeviceInfosRes := api.QueryDeviceInfosResponse{} + queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}} + if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil { + t.Fatal(err) + } + gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName + if tc.inputData.DeviceDisplayName != nil { + wantDisplayName := *tc.inputData.DeviceDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } else { + wantDisplayName := newDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + }) +} diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc7999..31617d8c1 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39bf..08d1371d6 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index f1d20259c..421852d3f 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { defer close() base, _, _ := testrig.Base(nil) defer base.Close() - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 6e62210e8..5f626b5bc 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -21,7 +21,7 @@ func TestCollect(t *testing.T) { b, _, _ := testrig.Base(nil) connStr, closeDB := test.PrepareDBConnectionString(t, dbType) defer closeDB() - db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil {