mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Merge branch 'master' into neilalexander/config
This commit is contained in:
commit
cc1d01cd28
|
|
@ -118,7 +118,9 @@ func (m *DendriteMonolith) Start() {
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
|
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer)
|
||||||
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(
|
rsAPI := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing, federation,
|
||||||
|
|
@ -156,7 +158,11 @@ func (m *DendriteMonolith) Start() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
|
<<<<<<< HEAD
|
||||||
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
|
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
|
||||||
|
=======
|
||||||
|
KeyAPI: keyAPI,
|
||||||
|
>>>>>>> master
|
||||||
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
||||||
ygg, fsAPI, federation,
|
ygg, fsAPI, federation,
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -115,33 +115,9 @@ func GetDevicesByLocalpart(
|
||||||
|
|
||||||
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
||||||
func UpdateDeviceByID(
|
func UpdateDeviceByID(
|
||||||
req *http.Request, deviceDB devices.Database, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := req.Context()
|
|
||||||
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown device"),
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if dev.UserID != device.UserID {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
JSON: jsonerror.Forbidden("device not owned by current user"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defer req.Body.Close() // nolint: errcheck
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
|
@ -152,10 +128,28 @@ func UpdateDeviceByID(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
|
var performRes api.PerformDeviceUpdateResponse
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed")
|
err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{
|
||||||
|
RequestingUserID: device.UserID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
DisplayName: payload.DisplayName,
|
||||||
|
}, &performRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
if !performRes.DeviceExists {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.Forbidden("device does not exist"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if performRes.Forbidden {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("device not owned by current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
@ -165,7 +159,7 @@ func UpdateDeviceByID(
|
||||||
|
|
||||||
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
|
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
|
||||||
func DeleteDeviceById(
|
func DeleteDeviceById(
|
||||||
req *http.Request, userInteractiveAuth *auth.UserInteractive, deviceDB devices.Database, device *api.Device,
|
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
|
|
@ -197,8 +191,12 @@ func DeleteDeviceById(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
|
var res api.PerformDeviceDeletionResponse
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.RemoveDevice failed")
|
if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{
|
||||||
|
UserID: device.UserID,
|
||||||
|
DeviceIDs: []string{deviceID},
|
||||||
|
}, &res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -210,26 +208,24 @@ func DeleteDeviceById(
|
||||||
|
|
||||||
// DeleteDevices handles POST requests to /delete_devices
|
// DeleteDevices handles POST requests to /delete_devices
|
||||||
func DeleteDevices(
|
func DeleteDevices(
|
||||||
req *http.Request, deviceDB devices.Database, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
payload := devicesDeleteJSON{}
|
payload := devicesDeleteJSON{}
|
||||||
|
|
||||||
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
|
util.GetLogger(ctx).WithError(err).Error("json.NewDecoder.Decode failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer req.Body.Close() // nolint: errcheck
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
|
var res api.PerformDeviceDeletionResponse
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed")
|
if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{
|
||||||
|
UserID: device.UserID,
|
||||||
|
DeviceIDs: payload.Devices,
|
||||||
|
}, &res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,8 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
@ -57,7 +57,7 @@ func passwordLogin() flows {
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
|
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
|
|
@ -81,7 +81,7 @@ func Login(
|
||||||
return *authErr
|
return *authErr
|
||||||
}
|
}
|
||||||
// make a device/access token
|
// make a device/access token
|
||||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login)
|
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
|
@ -90,7 +90,7 @@ func Login(
|
||||||
}
|
}
|
||||||
|
|
||||||
func completeAuth(
|
func completeAuth(
|
||||||
ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login,
|
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
token, err := auth.GenerateAccessToken()
|
token, err := auth.GenerateAccessToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -104,9 +104,13 @@ func completeAuth(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err := deviceDB.CreateDevice(
|
var performRes userapi.PerformDeviceCreationResponse
|
||||||
ctx, localpart, login.DeviceID, token, login.InitialDisplayName,
|
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
||||||
)
|
DeviceDisplayName: login.InitialDisplayName,
|
||||||
|
DeviceID: login.DeviceID,
|
||||||
|
AccessToken: token,
|
||||||
|
Localpart: localpart,
|
||||||
|
}, &performRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
|
|
@ -117,10 +121,10 @@ func completeAuth(
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: loginResponse{
|
JSON: loginResponse{
|
||||||
UserID: dev.UserID,
|
UserID: performRes.Device.UserID,
|
||||||
AccessToken: dev.AccessToken,
|
AccessToken: performRes.Device.AccessToken,
|
||||||
HomeServer: serverName,
|
HomeServer: serverName,
|
||||||
DeviceID: dev.ID,
|
DeviceID: performRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -387,7 +387,7 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/login",
|
r0mux.Handle("/login",
|
||||||
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
return Login(req, accountDB, deviceDB, cfg)
|
return Login(req, accountDB, userAPI, cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -644,7 +644,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
|
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
|
@ -654,13 +654,13 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return DeleteDeviceById(req, userInteractiveAuth, deviceDB, device, vars["deviceID"])
|
return DeleteDeviceById(req, userInteractiveAuth, userAPI, device, vars["deviceID"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodDelete, http.MethodOptions)
|
).Methods(http.MethodDelete, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/delete_devices",
|
r0mux.Handle("/delete_devices",
|
||||||
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return DeleteDevices(req, deviceDB, device)
|
return DeleteDevices(req, userAPI, device)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,9 @@ func main() {
|
||||||
accountDB := base.Base.CreateAccountsDB()
|
accountDB := base.Base.CreateAccountsDB()
|
||||||
deviceDB := base.Base.CreateDeviceDB()
|
deviceDB := base.Base.CreateDeviceDB()
|
||||||
federation := createFederationClient(base)
|
federation := createFederationClient(base)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
|
keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer)
|
||||||
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
serverKeyAPI := serverkeyapi.NewInternalAPI(
|
serverKeyAPI := serverkeyapi.NewInternalAPI(
|
||||||
&base.Base.Cfg.ServerKeyAPI, federation, base.Base.Caches,
|
&base.Base.Cfg.ServerKeyAPI, federation, base.Base.Caches,
|
||||||
|
|
@ -189,7 +191,7 @@ func main() {
|
||||||
ServerKeyAPI: serverKeyAPI,
|
ServerKeyAPI: serverKeyAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, userAPI, base.Base.KafkaProducer),
|
KeyAPI: keyAPI,
|
||||||
ExtPublicRoomsProvider: provider,
|
ExtPublicRoomsProvider: provider,
|
||||||
}
|
}
|
||||||
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)
|
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,9 @@ func main() {
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
|
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
|
||||||
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
rsComponent := roomserver.NewInternalAPI(
|
rsComponent := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing, federation,
|
||||||
|
|
@ -144,8 +146,7 @@ func main() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
|
KeyAPI: keyAPI,
|
||||||
//ServerKeyAPI: serverKeyAPI,
|
|
||||||
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
||||||
ygg, fsAPI, federation,
|
ygg, fsAPI, federation,
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,8 @@ func main() {
|
||||||
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
|
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
|
||||||
defer base.Close() // nolint: errcheck
|
defer base.Close() // nolint: errcheck
|
||||||
|
|
||||||
intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.UserAPIClient(), base.KafkaProducer)
|
intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.KafkaProducer)
|
||||||
|
intAPI.SetUserAPI(base.UserAPIClient())
|
||||||
|
|
||||||
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)
|
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,9 @@ func main() {
|
||||||
serverKeyAPI = base.ServerKeyAPIClient()
|
serverKeyAPI = base.ServerKeyAPIClient()
|
||||||
}
|
}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
|
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
|
||||||
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
rsImpl := roomserver.NewInternalAPI(
|
rsImpl := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing, federation,
|
||||||
|
|
@ -121,7 +123,6 @@ func main() {
|
||||||
rsImpl.SetFederationSenderAPI(fsAPI)
|
rsImpl.SetFederationSenderAPI(fsAPI)
|
||||||
|
|
||||||
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
|
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
|
||||||
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer)
|
|
||||||
|
|
||||||
monolith := setup.Monolith{
|
monolith := setup.Monolith{
|
||||||
Config: base.Cfg,
|
Config: base.Cfg,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ func main() {
|
||||||
accountDB := base.CreateAccountsDB()
|
accountDB := base.CreateAccountsDB()
|
||||||
deviceDB := base.CreateDeviceDB()
|
deviceDB := base.CreateDeviceDB()
|
||||||
|
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
|
||||||
|
|
||||||
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,9 @@ func main() {
|
||||||
accountDB := base.CreateAccountsDB()
|
accountDB := base.CreateAccountsDB()
|
||||||
deviceDB := base.CreateDeviceDB()
|
deviceDB := base.CreateDeviceDB()
|
||||||
federation := createFederationClient(cfg, node)
|
federation := createFederationClient(cfg, node)
|
||||||
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
|
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
|
||||||
|
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
fetcher := &libp2pKeyFetcher{}
|
fetcher := &libp2pKeyFetcher{}
|
||||||
keyRing := gomatrixserverlib.KeyRing{
|
keyRing := gomatrixserverlib.KeyRing{
|
||||||
|
|
@ -233,7 +235,7 @@ func main() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
|
KeyAPI: keyAPI,
|
||||||
//ServerKeyAPI: serverKeyAPI,
|
//ServerKeyAPI: serverKeyAPI,
|
||||||
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -92,13 +93,14 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) {
|
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
|
||||||
cfg := &config.Dendrite{}
|
cfg := &config.Dendrite{}
|
||||||
cfg.Defaults()
|
stateDBName := "test_state.db"
|
||||||
|
naffkaDBName := "test_naffka.db"
|
||||||
cfg.Global.ServerName = "kaer.morhen"
|
cfg.Global.ServerName = "kaer.morhen"
|
||||||
cfg.Global.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic)
|
cfg.Global.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic)
|
||||||
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file::memory:")
|
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file:" + stateDBName)
|
||||||
db, err := sqlutil.Open(&cfg.CurrentStateServer.Database)
|
db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open naffka database: %s", err)
|
t.Fatalf("Failed to open naffka database: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -110,11 +112,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create naffka consumer: %s", err)
|
t.Fatalf("Failed to create naffka consumer: %s", err)
|
||||||
}
|
}
|
||||||
return NewInternalAPI(&cfg.CurrentStateServer, naff), naff
|
return NewInternalAPI(cfg, naff), naff, func() {
|
||||||
|
os.Remove(naffkaDBName)
|
||||||
|
os.Remove(stateDBName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryCurrentState(t *testing.T) {
|
func TestQueryCurrentState(t *testing.T) {
|
||||||
currStateAPI, producer := MustMakeInternalAPI(t)
|
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
||||||
|
defer cancel()
|
||||||
plTuple := gomatrixserverlib.StateKeyTuple{
|
plTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
EventType: "m.room.power_levels",
|
EventType: "m.room.power_levels",
|
||||||
StateKey: "",
|
StateKey: "",
|
||||||
|
|
@ -217,7 +223,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
|
||||||
|
|
||||||
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
||||||
func TestQuerySharedUsers(t *testing.T) {
|
func TestQuerySharedUsers(t *testing.T) {
|
||||||
currStateAPI, producer := MustMakeInternalAPI(t)
|
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
||||||
|
defer cancel()
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
||||||
|
|
||||||
|
|
@ -230,6 +237,9 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
|
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
||||||
|
|
||||||
|
// we don't know when the server has processed the events
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
req api.QuerySharedUsersRequest
|
req api.QuerySharedUsersRequest
|
||||||
wantRes api.QuerySharedUsersResponse
|
wantRes api.QuerySharedUsersResponse
|
||||||
|
|
|
||||||
|
|
@ -19,14 +19,19 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyInternalAPI interface {
|
type KeyInternalAPI interface {
|
||||||
|
// SetUserAPI assigns a user API to query when extracting device names.
|
||||||
|
SetUserAPI(i userapi.UserInternalAPI)
|
||||||
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
||||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
||||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
|
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
|
||||||
|
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyError is returned if there was a problem performing/querying the server
|
// KeyError is returned if there was a problem performing/querying the server
|
||||||
|
|
@ -38,6 +43,13 @@ func (k *KeyError) Error() string {
|
||||||
return k.Err
|
return k.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceMessage represents the message produced into Kafka by the key server.
|
||||||
|
type DeviceMessage struct {
|
||||||
|
DeviceKeys
|
||||||
|
// A monotonically increasing number which represents device changes for this user.
|
||||||
|
StreamID int
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceKeys represents a set of device keys for a single device
|
// DeviceKeys represents a set of device keys for a single device
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
type DeviceKeys struct {
|
type DeviceKeys struct {
|
||||||
|
|
@ -45,10 +57,20 @@ type DeviceKeys struct {
|
||||||
UserID string
|
UserID string
|
||||||
// The device ID of this device
|
// The device ID of this device
|
||||||
DeviceID string
|
DeviceID string
|
||||||
|
// The device display name
|
||||||
|
DisplayName string
|
||||||
// The raw device key JSON
|
// The raw device key JSON
|
||||||
KeyJSON []byte
|
KeyJSON []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStreamID returns a copy of this device message with the given stream ID
|
||||||
|
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
|
||||||
|
return DeviceMessage{
|
||||||
|
DeviceKeys: *k,
|
||||||
|
StreamID: streamID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OneTimeKeys represents a set of one-time keys for a single device
|
// OneTimeKeys represents a set of one-time keys for a single device
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
type OneTimeKeys struct {
|
type OneTimeKeys struct {
|
||||||
|
|
@ -153,3 +175,16 @@ type QueryKeyChangesResponse struct {
|
||||||
// Set if there was a problem handling the request.
|
// Set if there was a problem handling the request.
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryOneTimeKeysRequest struct {
|
||||||
|
// The local user to query OTK counts for
|
||||||
|
UserID string
|
||||||
|
// The device to query OTK counts for
|
||||||
|
DeviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryOneTimeKeysResponse struct {
|
||||||
|
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||||
|
Count OneTimeKeysCount
|
||||||
|
Error *KeyError
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,10 @@ type KeyInternalAPI struct {
|
||||||
Producer *producers.KeyChange
|
Producer *producers.KeyChange
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||||
|
a.UserAPI = i
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||||
if req.Partition < 0 {
|
if req.Partition < 0 {
|
||||||
req.Partition = a.Producer.DefaultPartition()
|
req.Partition = a.Producer.DefaultPartition()
|
||||||
|
|
@ -57,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||||
a.uploadDeviceKeys(ctx, req, res)
|
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -164,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
|
||||||
|
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Count = *count
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||||
res.Failures = make(map[string]interface{})
|
res.Failures = make(map[string]interface{})
|
||||||
|
|
@ -202,6 +217,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
for _, dk := range deviceKeys {
|
for _, dk := range deviceKeys {
|
||||||
|
if len(dk.KeyJSON) == 0 {
|
||||||
|
continue // don't include blank keys
|
||||||
|
}
|
||||||
// inject display name if known
|
// inject display name if known
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
|
@ -268,14 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
var keysToStore []api.DeviceKeys
|
var keysToStore []api.DeviceMessage
|
||||||
// assert that the user ID / device ID are not lying for each key
|
// assert that the user ID / device ID are not lying for each key
|
||||||
for _, key := range req.DeviceKeys {
|
for _, key := range req.DeviceKeys {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
continue // ignore invalid users
|
||||||
|
}
|
||||||
|
if serverName != a.ThisServer {
|
||||||
|
continue // ignore remote users
|
||||||
|
}
|
||||||
|
if len(key.KeyJSON) == 0 {
|
||||||
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||||
|
continue // deleted keys don't need sanity checking
|
||||||
|
}
|
||||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||||
keysToStore = append(keysToStore, key)
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -286,12 +315,15 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// get existing device keys so we can check for changes
|
// get existing device keys so we can check for changes
|
||||||
existingKeys := make([]api.DeviceKeys, len(keysToStore))
|
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
||||||
for i := range keysToStore {
|
for i := range keysToStore {
|
||||||
existingKeys[i] = api.DeviceKeys{
|
existingKeys[i] = api.DeviceMessage{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
UserID: keysToStore[i].UserID,
|
UserID: keysToStore[i].UserID,
|
||||||
DeviceID: keysToStore[i].DeviceID,
|
DeviceID: keysToStore[i].DeviceID,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||||
|
|
@ -301,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
|
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
|
||||||
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := a.emitDeviceKeyChanges(existingKeys, keysToStore)
|
err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -352,13 +385,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error {
|
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
|
||||||
// find keys in new that are not in existing
|
// find keys in new that are not in existing
|
||||||
var keysAdded []api.DeviceKeys
|
var keysAdded []api.DeviceMessage
|
||||||
for _, newKey := range new {
|
for _, newKey := range new {
|
||||||
exists := false
|
exists := false
|
||||||
for _, existingKey := range existing {
|
for _, existingKey := range existing {
|
||||||
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) {
|
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
||||||
|
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
||||||
|
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
|
||||||
exists = true
|
exists = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -30,6 +31,7 @@ const (
|
||||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||||
QueryKeysPath = "/keyserver/queryKeys"
|
QueryKeysPath = "/keyserver/queryKeys"
|
||||||
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||||
|
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -52,6 +54,10 @@ type httpKeyInternalAPI struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||||
|
// no-op: doesn't need it
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpKeyInternalAPI) PerformClaimKeys(
|
func (h *httpKeyInternalAPI) PerformClaimKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.PerformClaimKeysRequest,
|
request *api.PerformClaimKeysRequest,
|
||||||
|
|
@ -103,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryOneTimeKeysRequest,
|
||||||
|
response *api.QueryOneTimeKeysResponse,
|
||||||
|
) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryOneTimeKeysPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
if err != nil {
|
||||||
|
response.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpKeyInternalAPI) QueryKeyChanges(
|
func (h *httpKeyInternalAPI) QueryKeyChanges(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryKeyChangesRequest,
|
request *api.QueryKeyChangesRequest,
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryOneTimeKeysPath,
|
||||||
|
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryOneTimeKeysRequest{}
|
||||||
|
response := api.QueryOneTimeKeysResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
s.QueryOneTimeKeys(req.Context(), &request, &response)
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(QueryKeyChangesPath,
|
internalAPIMux.Handle(QueryKeyChangesPath,
|
||||||
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryKeyChangesRequest{}
|
request := api.QueryKeyChangesRequest{}
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/keyserver/inthttp"
|
"github.com/matrix-org/dendrite/keyserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
@ -37,7 +36,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
||||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, producer sarama.SyncProducer,
|
cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, producer sarama.SyncProducer,
|
||||||
) api.KeyInternalAPI {
|
) api.KeyInternalAPI {
|
||||||
db, err := storage.NewDatabase(&cfg.Database)
|
db, err := storage.NewDatabase(&cfg.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -52,7 +51,6 @@ func NewInternalAPI(
|
||||||
DB: db,
|
DB: db,
|
||||||
ThisServer: cfg.Matrix.ServerName,
|
ThisServer: cfg.Matrix.ServerName,
|
||||||
FedClient: fedClient,
|
FedClient: fedClient,
|
||||||
UserAPI: userAPI,
|
|
||||||
Producer: keyChangeProducer,
|
Producer: keyChangeProducer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
|
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
var m sarama.ProducerMessage
|
var m sarama.ProducerMessage
|
||||||
|
|
||||||
|
|
@ -67,6 +67,7 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
|
||||||
"device_id": key.DeviceID,
|
"device_id": key.DeviceID,
|
||||||
"partition": partition,
|
"partition": partition,
|
||||||
"offset": offset,
|
"offset": offset,
|
||||||
|
"len_key_bytes": len(key.KeyJSON),
|
||||||
}).Infof("Produced to key change topic '%s'", p.Topic)
|
}).Infof("Produced to key change topic '%s'", p.Topic)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -29,16 +29,21 @@ type Database interface {
|
||||||
// StoreOneTimeKeys persists the given one-time keys.
|
// StoreOneTimeKeys persists the given one-time keys.
|
||||||
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
|
// for this (user, device).
|
||||||
|
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||||
// Returns an error if there was a problem storing the keys.
|
// Returns an error if there was a problem storing the keys.
|
||||||
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
|
|
||||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
|
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
|
||||||
|
-- This means we do not store an unbounded append-only log of device keys, which is not actually
|
||||||
|
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||||
|
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
||||||
" VALUES ($1, $2, $3, $4)" +
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||||
" DO UPDATE SET key_json = $4"
|
" DO UPDATE SET key_json = $4, stream_id = $5"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectMaxStreamForUserSQL = "" +
|
||||||
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
|
var streamID int
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
|
keys[i].StreamID = streamID
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
now := time.Now().Unix()
|
// nullable if there are no results
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
var nullStream sql.NullInt32
|
||||||
|
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if nullStream.Valid {
|
||||||
|
streamID = nullStream.Int32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
var result []api.DeviceKeys
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceKeys
|
var dk api.DeviceMessage
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
|
var streamID int
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
|
dk.StreamID = streamID
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: deviceID,
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
counts := &api.OneTimeKeysCount{
|
counts := &api.OneTimeKeysCount{
|
||||||
|
|
|
||||||
|
|
@ -39,15 +39,40 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
|
||||||
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
|
// work out the latest stream IDs for each user
|
||||||
|
userIDToStreamID := make(map[string]int)
|
||||||
|
for _, k := range keys {
|
||||||
|
userIDToStreamID[k.UserID] = 0
|
||||||
|
}
|
||||||
|
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
for userID := range userIDToStreamID {
|
||||||
|
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userIDToStreamID[userID] = int(streamID)
|
||||||
|
}
|
||||||
|
// set the stream IDs for each key
|
||||||
|
for i := range keys {
|
||||||
|
k := keys[i]
|
||||||
|
userIDToStreamID[k.UserID]++ // start stream from 1
|
||||||
|
k.StreamID = userIDToStreamID[k.UserID]
|
||||||
|
keys[i] = k
|
||||||
|
}
|
||||||
|
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
UNIQUE (user_id, device_id)
|
UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
||||||
" VALUES ($1, $2, $3, $4)" +
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
" ON CONFLICT (user_id, device_id)" +
|
" ON CONFLICT (user_id, device_id)" +
|
||||||
" DO UPDATE SET key_json = $4"
|
" DO UPDATE SET key_json = $4, stream_id = $5"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectMaxStreamForUserSQL = "" +
|
||||||
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
|
|
@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||||
var result []api.DeviceKeys
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceKeys
|
var dk api.DeviceMessage
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
|
var streamID int
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
|
dk.StreamID = streamID
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
|
var streamID int
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
|
keys[i].StreamID = streamID
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
now := time.Now().Unix()
|
// nullable if there are no results
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
var nullStream sql.NullInt32
|
||||||
|
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if nullStream.Valid {
|
||||||
|
streamID = nullStream.Int32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: deviceID,
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
counts := &api.OneTimeKeysCount{
|
counts := &api.OneTimeKeysCount{
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
@ -82,3 +83,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||||
|
// and that they are returned correctly when querying for device keys.
|
||||||
|
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||||
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||||
|
msgs := []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: bob,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 1 as this is a different user
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "another_device",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 2 as this is a 2nd device key
|
||||||
|
},
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[1].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[2].StreamID != 2 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updating a device sets the next stream ID for that user
|
||||||
|
msgs = []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v2"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 3
|
||||||
|
},
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 3 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Querying for device keys returns the latest stream IDs
|
||||||
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||||
|
}
|
||||||
|
wantStreamIDs := map[string]int{
|
||||||
|
"AAA": 3,
|
||||||
|
"another_device": 2,
|
||||||
|
}
|
||||||
|
if len(msgs) != len(wantStreamIDs) {
|
||||||
|
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||||
|
}
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||||
|
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
type OneTimeKeys interface {
|
type OneTimeKeys interface {
|
||||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||||
// Returns an empty map if the key does not exist.
|
// Returns an empty map if the key does not exist.
|
||||||
|
|
@ -31,9 +32,10 @@ type OneTimeKeys interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
||||||
defer func() {
|
defer func() {
|
||||||
s.updateOffset(msg)
|
s.updateOffset(msg)
|
||||||
}()
|
}()
|
||||||
var output api.DeviceKeys
|
var output api.DeviceMessage
|
||||||
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
||||||
// If the message was invalid, log it and move on to the next message in the stream
|
// If the message was invalid, log it and move on to the next message in the stream
|
||||||
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")
|
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ type OutputRoomEventConsumer struct {
|
||||||
rsConsumer *internal.ContinualConsumer
|
rsConsumer *internal.ContinualConsumer
|
||||||
db storage.Database
|
db storage.Database
|
||||||
notifier *sync.Notifier
|
notifier *sync.Notifier
|
||||||
|
keyChanges *OutputKeyChangeEventConsumer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
||||||
|
|
@ -44,6 +45,7 @@ func NewOutputRoomEventConsumer(
|
||||||
n *sync.Notifier,
|
n *sync.Notifier,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.RoomserverInternalAPI,
|
||||||
|
keyChanges *OutputKeyChangeEventConsumer,
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
|
|
||||||
consumer := internal.ContinualConsumer{
|
consumer := internal.ContinualConsumer{
|
||||||
|
|
@ -56,6 +58,7 @@ func NewOutputRoomEventConsumer(
|
||||||
db: store,
|
db: store,
|
||||||
notifier: n,
|
notifier: n,
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
|
keyChanges: keyChanges,
|
||||||
}
|
}
|
||||||
consumer.ProcessMessage = s.onMessage
|
consumer.ProcessMessage = s.onMessage
|
||||||
|
|
||||||
|
|
@ -160,9 +163,29 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
||||||
}
|
}
|
||||||
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
|
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
|
||||||
|
|
||||||
|
s.notifyKeyChanges(&ev)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.HeaderedEvent) {
|
||||||
|
if ev.Type() != gomatrixserverlib.MRoomMember || ev.StateKey() == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
membership, err := ev.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch membership {
|
||||||
|
case gomatrixserverlib.Join:
|
||||||
|
s.keyChanges.OnJoinEvent(ev)
|
||||||
|
case gomatrixserverlib.Ban:
|
||||||
|
fallthrough
|
||||||
|
case gomatrixserverlib.Leave:
|
||||||
|
s.keyChanges.OnLeaveEvent(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
||||||
ctx context.Context, msg api.OutputNewInviteEvent,
|
ctx context.Context, msg api.OutputNewInviteEvent,
|
||||||
) error {
|
) error {
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
|
@ -28,6 +29,20 @@ import (
|
||||||
|
|
||||||
const DeviceListLogName = "dl"
|
const DeviceListLogName = "dl"
|
||||||
|
|
||||||
|
// DeviceOTKCounts adds one-time key counts to the /sync response
|
||||||
|
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error {
|
||||||
|
var queryRes api.QueryOneTimeKeysResponse
|
||||||
|
keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
}, &queryRes)
|
||||||
|
if queryRes.Error != nil {
|
||||||
|
return queryRes.Error
|
||||||
|
}
|
||||||
|
res.DeviceListsOTKCount = queryRes.Count.KeyCount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
|
|
@ -35,6 +50,7 @@ func DeviceListCatchup(
|
||||||
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
userID string, res *types.Response, from, to types.StreamingToken,
|
userID string, res *types.Response, from, to types.StreamingToken,
|
||||||
) (hasNew bool, err error) {
|
) (hasNew bool, err error) {
|
||||||
|
|
||||||
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
||||||
newlyJoinedRooms := joinedRooms(res, userID)
|
newlyJoinedRooms := joinedRooms(res, userID)
|
||||||
newlyLeftRooms := leftRooms(res)
|
newlyLeftRooms := leftRooms(res)
|
||||||
|
|
@ -88,6 +104,16 @@ func DeviceListCatchup(
|
||||||
if !userSet[userID] {
|
if !userSet[userID] {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
hasNew = true
|
hasNew = true
|
||||||
|
userSet[userID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if the response has any join/leave events, add them now.
|
||||||
|
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||||
|
for _, userID := range membershipEvents(res) {
|
||||||
|
if !userSet[userID] {
|
||||||
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
|
hasNew = true
|
||||||
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hasNew, nil
|
return hasNew, nil
|
||||||
|
|
@ -219,3 +245,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to
|
||||||
|
// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||||
|
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
|
||||||
|
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
|
||||||
|
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
|
||||||
|
func membershipEvents(res *types.Response) (userIDs []string) {
|
||||||
|
for _, room := range res.Rooms.Join {
|
||||||
|
for _, ev := range room.Timeline.Events {
|
||||||
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
|
||||||
|
if strings.Contains(string(ev.Content), `"join"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
} else if strings.Contains(string(ev.Content), `"leave"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
} else if strings.Contains(string(ev.Content), `"ban"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/api"
|
"github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,12 +30,17 @@ type mockKeyAPI struct{}
|
||||||
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
|
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
|
||||||
|
|
||||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
|
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
|
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
|
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
|
||||||
|
}
|
||||||
|
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockCurrentStateAPI struct {
|
type mockCurrentStateAPI struct {
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
}
|
}
|
||||||
// work out room joins/leaves
|
// work out room joins/leaves
|
||||||
res, err := rp.db.IncrementalSync(
|
res, err := rp.db.IncrementalSync(
|
||||||
req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false,
|
req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync")
|
util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync")
|
||||||
|
|
@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
// nolint:gocyclo
|
||||||
res = types.NewResponse()
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) {
|
||||||
|
res := types.NewResponse()
|
||||||
|
|
||||||
since := types.NewStreamToken(0, 0, nil)
|
since := types.NewStreamToken(0, 0, nil)
|
||||||
if req.since != nil {
|
if req.since != nil {
|
||||||
|
|
@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
|
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
|
||||||
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
|
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos)
|
res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
|
}
|
||||||
|
err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before we return the sync response, make sure that we take action on
|
// Before we return the sync response, make sure that we take action on
|
||||||
|
|
@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
// Handle the updates and deletions in the database.
|
// Handle the updates and deletions in the database.
|
||||||
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
|
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(events) > 0 {
|
if len(events) > 0 {
|
||||||
|
|
@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) appendDeviceLists(
|
func (rp *RequestPool) appendDeviceLists(
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,16 @@ func AddPublicRoutes(
|
||||||
|
|
||||||
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
|
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
|
||||||
|
|
||||||
|
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
|
||||||
|
cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
|
||||||
|
consumer, notifier, keyAPI, currentStateAPI, syncDB,
|
||||||
|
)
|
||||||
|
if err = keyChangeConsumer.Start(); err != nil {
|
||||||
|
logrus.WithError(err).Panicf("failed to start key change consumer")
|
||||||
|
}
|
||||||
|
|
||||||
roomConsumer := consumers.NewOutputRoomEventConsumer(
|
roomConsumer := consumers.NewOutputRoomEventConsumer(
|
||||||
cfg, consumer, notifier, syncDB, rsAPI,
|
cfg, consumer, notifier, syncDB, rsAPI, keyChangeConsumer,
|
||||||
)
|
)
|
||||||
if err = roomConsumer.Start(); err != nil {
|
if err = roomConsumer.Start(); err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to start room server consumer")
|
logrus.WithError(err).Panicf("failed to start room server consumer")
|
||||||
|
|
@ -92,13 +100,5 @@ func AddPublicRoutes(
|
||||||
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
|
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
|
||||||
}
|
}
|
||||||
|
|
||||||
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
|
|
||||||
cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
|
|
||||||
consumer, notifier, keyAPI, currentStateAPI, syncDB,
|
|
||||||
)
|
|
||||||
if err = keyChangeConsumer.Start(); err != nil {
|
|
||||||
logrus.WithError(err).Panicf("failed to start key change consumer")
|
|
||||||
}
|
|
||||||
|
|
||||||
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
|
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -393,6 +393,7 @@ type Response struct {
|
||||||
Changed []string `json:"changed,omitempty"`
|
Changed []string `json:"changed,omitempty"`
|
||||||
Left []string `json:"left,omitempty"`
|
Left []string `json:"left,omitempty"`
|
||||||
} `json:"device_lists,omitempty"`
|
} `json:"device_lists,omitempty"`
|
||||||
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponse creates an empty response with initialised maps.
|
// NewResponse creates an empty response with initialised maps.
|
||||||
|
|
@ -411,6 +412,7 @@ func NewResponse() *Response {
|
||||||
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
||||||
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
||||||
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
|
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
|
||||||
|
res.DeviceListsOTKCount = make(map[string]int)
|
||||||
|
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync
|
||||||
Sync can be polled for updates
|
Sync can be polled for updates
|
||||||
Sync is woken up for leaves
|
Sync is woken up for leaves
|
||||||
Newly left rooms appear in the leave section of incremental sync
|
Newly left rooms appear in the leave section of incremental sync
|
||||||
|
Rooms can be created with an initial invite list (SYN-205)
|
||||||
We should see our own leave event, even if history_visibility is restricted (SYN-662)
|
We should see our own leave event, even if history_visibility is restricted (SYN-662)
|
||||||
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
|
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
|
||||||
Newly left rooms appear in the leave section of gapped sync
|
Newly left rooms appear in the leave section of gapped sync
|
||||||
|
|
@ -129,6 +130,11 @@ Can claim one time key using POST
|
||||||
Can claim remote one time key using POST
|
Can claim remote one time key using POST
|
||||||
Local device key changes appear in v2 /sync
|
Local device key changes appear in v2 /sync
|
||||||
Local device key changes appear in /keys/changes
|
Local device key changes appear in /keys/changes
|
||||||
|
New users appear in /keys/changes
|
||||||
|
Local delete device changes appear in v2 /sync
|
||||||
|
Local new device changes appear in v2 /sync
|
||||||
|
Local update device changes appear in v2 /sync
|
||||||
|
Users receive device_list updates for their own devices
|
||||||
Get left notifs for other users in sync and /keys/changes when user leaves
|
Get left notifs for other users in sync and /keys/changes when user leaves
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ type UserInternalAPI interface {
|
||||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
|
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
|
||||||
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
|
|
@ -47,6 +49,25 @@ type InputAccountDataRequest struct {
|
||||||
type InputAccountDataResponse struct {
|
type InputAccountDataResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PerformDeviceUpdateRequest struct {
|
||||||
|
RequestingUserID string
|
||||||
|
DeviceID string
|
||||||
|
DisplayName *string
|
||||||
|
}
|
||||||
|
type PerformDeviceUpdateResponse struct {
|
||||||
|
DeviceExists bool
|
||||||
|
Forbidden bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerformDeviceDeletionRequest struct {
|
||||||
|
UserID string
|
||||||
|
// The devices to delete
|
||||||
|
DeviceIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerformDeviceDeletionResponse struct {
|
||||||
|
}
|
||||||
|
|
||||||
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
|
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
|
||||||
type QueryDeviceInfosRequest struct {
|
type QueryDeviceInfosRequest struct {
|
||||||
DeviceIDs []string
|
DeviceIDs []string
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,12 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserInternalAPI struct {
|
type UserInternalAPI struct {
|
||||||
|
|
@ -37,6 +39,7 @@ type UserInternalAPI struct {
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
// AppServices is the list of all registered AS
|
// AppServices is the list of all registered AS
|
||||||
AppServices []config.ApplicationService
|
AppServices []config.ApplicationService
|
||||||
|
KeyAPI keyapi.KeyInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||||
|
|
@ -101,6 +104,76 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
res.DeviceCreated = true
|
res.DeviceCreated = true
|
||||||
res.Device = dev
|
res.Device = dev
|
||||||
|
// create empty device keys and upload them to trigger device list changes
|
||||||
|
return a.deviceListUpdate(dev.UserID, []string{dev.ID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error {
|
||||||
|
util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion")
|
||||||
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if domain != a.ServerName {
|
||||||
|
return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName)
|
||||||
|
}
|
||||||
|
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// create empty device keys and upload them to delete what was once there and trigger device list changes
|
||||||
|
return a.deviceListUpdate(req.UserID, req.DeviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
||||||
|
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
|
||||||
|
for i, did := range deviceIDs {
|
||||||
|
deviceKeys[i] = keyapi.DeviceKeys{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: did,
|
||||||
|
KeyJSON: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var uploadRes keyapi.PerformUploadKeysResponse
|
||||||
|
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||||
|
DeviceKeys: deviceKeys,
|
||||||
|
}, &uploadRes)
|
||||||
|
if uploadRes.Error != nil {
|
||||||
|
return fmt.Errorf("Failed to delete device keys: %v", uploadRes.Error)
|
||||||
|
}
|
||||||
|
if len(uploadRes.KeyErrors) > 0 {
|
||||||
|
return fmt.Errorf("Failed to delete device keys, key errors: %+v", uploadRes.KeyErrors)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
res.DeviceExists = false
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.DeviceExists = true
|
||||||
|
|
||||||
|
if dev.UserID != req.RequestingUserID {
|
||||||
|
res.Forbidden = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ const (
|
||||||
|
|
||||||
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
||||||
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
||||||
|
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
|
||||||
|
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
||||||
|
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||||
|
|
@ -91,6 +93,26 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformDeviceDeletion(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.PerformDeviceDeletionRequest,
|
||||||
|
response *api.PerformDeviceDeletionResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformDeviceDeletionPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformDeviceUpdatePath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryProfile(
|
func (h *httpUserInternalAPI) QueryProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryProfileRequest,
|
request *api.QueryProfileRequest,
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(PerformDeviceUpdatePath,
|
||||||
|
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformDeviceUpdateRequest{}
|
||||||
|
response := api.PerformDeviceUpdateResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(PerformDeviceDeletionPath,
|
||||||
|
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformDeviceDeletionRequest{}
|
||||||
|
response := api.PerformDeviceDeletionResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(QueryProfilePath,
|
internalAPIMux.Handle(QueryProfilePath,
|
||||||
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryProfileRequest{}
|
request := api.QueryProfileRequest{}
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice(
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1)
|
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||||
prep, err := s.db.Prepare(orig)
|
prep, err := s.db.Prepare(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices(
|
||||||
for i, v := range devices {
|
for i, v := range devices {
|
||||||
params[i+1] = v
|
params[i+1] = v
|
||||||
}
|
}
|
||||||
params = append(params, params...)
|
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ package userapi
|
||||||
import (
|
import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/internal"
|
"github.com/matrix-org/dendrite/userapi/internal"
|
||||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
|
|
@ -34,12 +35,13 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
||||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database,
|
func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database,
|
||||||
serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI {
|
serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI {
|
||||||
|
|
||||||
return &internal.UserInternalAPI{
|
return &internal.UserInternalAPI{
|
||||||
AccountDB: accountDB,
|
AccountDB: accountDB,
|
||||||
DeviceDB: deviceDB,
|
DeviceDB: deviceDB,
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
AppServices: appServices,
|
AppServices: appServices,
|
||||||
|
KeyAPI: keyAPI,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database,
|
||||||
t.Fatalf("failed to create device DB: %s", err)
|
t.Fatalf("failed to create device DB: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB
|
return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryProfile(t *testing.T) {
|
func TestQueryProfile(t *testing.T) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue