Merge branch 'master' into neilalexander/config

This commit is contained in:
Neil Alexander 2020-08-04 09:43:02 +01:00
commit cc1d01cd28
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
40 changed files with 751 additions and 186 deletions

View file

@ -118,7 +118,9 @@ func (m *DendriteMonolith) Start() {
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, base.KafkaProducer)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsAPI := roomserver.NewInternalAPI(
base, keyRing, federation,
@ -156,7 +158,11 @@ func (m *DendriteMonolith) Start() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
StateAPI: stateAPI,
<<<<<<< HEAD
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
=======
KeyAPI: keyAPI,
>>>>>>> master
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),

View file

@ -115,33 +115,9 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID(
req *http.Request, deviceDB devices.Database, device *api.Device,
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
if err == sql.ErrNoRows {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown device"),
}
} else if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
return jsonerror.InternalServerError()
}
if dev.UserID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("device not owned by current user"),
}
}
defer req.Body.Close() // nolint: errcheck
@ -152,10 +128,28 @@ func UpdateDeviceByID(
return jsonerror.InternalServerError()
}
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed")
var performRes api.PerformDeviceUpdateResponse
err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{
RequestingUserID: device.UserID,
DeviceID: deviceID,
DisplayName: payload.DisplayName,
}, &performRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed")
return jsonerror.InternalServerError()
}
if !performRes.DeviceExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("device does not exist"),
}
}
if performRes.Forbidden {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("device not owned by current user"),
}
}
return util.JSONResponse{
Code: http.StatusOK,
@ -165,7 +159,7 @@ func UpdateDeviceByID(
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
func DeleteDeviceById(
req *http.Request, userInteractiveAuth *auth.UserInteractive, deviceDB devices.Database, device *api.Device,
req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.UserInternalAPI, device *api.Device,
deviceID string,
) util.JSONResponse {
ctx := req.Context()
@ -197,8 +191,12 @@ func DeleteDeviceById(
}
}
if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.RemoveDevice failed")
var res api.PerformDeviceDeletionResponse
if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: []string{deviceID},
}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
@ -210,26 +208,24 @@ func DeleteDeviceById(
// DeleteDevices handles POST requests to /delete_devices
func DeleteDevices(
req *http.Request, deviceDB devices.Database, device *api.Device,
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
ctx := req.Context()
payload := devicesDeleteJSON{}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.NewDecoder.Decode failed")
util.GetLogger(ctx).WithError(err).Error("json.NewDecoder.Decode failed")
return jsonerror.InternalServerError()
}
defer req.Body.Close() // nolint: errcheck
if err := deviceDB.RemoveDevices(ctx, localpart, payload.Devices); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevices failed")
var res api.PerformDeviceDeletionResponse
if err := userAPI.PerformDeviceDeletion(ctx, &api.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: payload.Devices,
}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}

View file

@ -23,8 +23,8 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@ -57,7 +57,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login
func Login(
req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
if req.Method == http.MethodGet {
@ -81,7 +81,7 @@ func Login(
return *authErr
}
// make a device/access token
return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login)
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
}
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
@ -90,7 +90,7 @@ func Login(
}
func completeAuth(
ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login,
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
) util.JSONResponse {
token, err := auth.GenerateAccessToken()
if err != nil {
@ -104,9 +104,13 @@ func completeAuth(
return jsonerror.InternalServerError()
}
dev, err := deviceDB.CreateDevice(
ctx, localpart, login.DeviceID, token, login.InitialDisplayName,
)
var performRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
DeviceDisplayName: login.InitialDisplayName,
DeviceID: login.DeviceID,
AccessToken: token,
Localpart: localpart,
}, &performRes)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
@ -117,10 +121,10 @@ func completeAuth(
return util.JSONResponse{
Code: http.StatusOK,
JSON: loginResponse{
UserID: dev.UserID,
AccessToken: dev.AccessToken,
UserID: performRes.Device.UserID,
AccessToken: performRes.Device.AccessToken,
HomeServer: serverName,
DeviceID: dev.ID,
DeviceID: performRes.Device.ID,
},
}
}

View file

@ -387,7 +387,7 @@ func Setup(
r0mux.Handle("/login",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
return Login(req, accountDB, deviceDB, cfg)
return Login(req, accountDB, userAPI, cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
@ -644,7 +644,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
}),
).Methods(http.MethodPut, http.MethodOptions)
@ -654,13 +654,13 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return DeleteDeviceById(req, userInteractiveAuth, deviceDB, device, vars["deviceID"])
return DeleteDeviceById(req, userInteractiveAuth, userAPI, device, vars["deviceID"])
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/delete_devices",
httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return DeleteDevices(req, deviceDB, device)
return DeleteDevices(req, userAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -144,7 +144,9 @@ func main() {
accountDB := base.Base.CreateAccountsDB()
deviceDB := base.Base.CreateDeviceDB()
federation := createFederationClient(base)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
serverKeyAPI := serverkeyapi.NewInternalAPI(
&base.Base.Cfg.ServerKeyAPI, federation, base.Base.Caches,
@ -189,7 +191,7 @@ func main() {
ServerKeyAPI: serverKeyAPI,
StateAPI: stateAPI,
UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, userAPI, base.Base.KafkaProducer),
KeyAPI: keyAPI,
ExtPublicRoomsProvider: provider,
}
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)

View file

@ -105,7 +105,9 @@ func main() {
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsComponent := roomserver.NewInternalAPI(
base, keyRing, federation,
@ -144,8 +146,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
StateAPI: stateAPI,
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
//ServerKeyAPI: serverKeyAPI,
KeyAPI: keyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),

View file

@ -24,7 +24,8 @@ func main() {
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
defer base.Close() // nolint: errcheck
intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.UserAPIClient(), base.KafkaProducer)
intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient(), base.KafkaProducer)
intAPI.SetUserAPI(base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)

View file

@ -78,7 +78,9 @@ func main() {
serverKeyAPI = base.ServerKeyAPIClient()
}
keyRing := serverKeyAPI.KeyRing()
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsImpl := roomserver.NewInternalAPI(
base, keyRing, federation,
@ -121,7 +123,6 @@ func main() {
rsImpl.SetFederationSenderAPI(fsAPI)
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer)
monolith := setup.Monolith{
Config: base.Cfg,

View file

@ -27,7 +27,7 @@ func main() {
accountDB := base.CreateAccountsDB()
deviceDB := base.CreateDeviceDB()
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)

View file

@ -196,7 +196,9 @@ func main() {
accountDB := base.CreateAccountsDB()
deviceDB := base.CreateDeviceDB()
federation := createFederationClient(cfg, node)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil)
keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer)
userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
fetcher := &libp2pKeyFetcher{}
keyRing := gomatrixserverlib.KeyRing{
@ -233,7 +235,7 @@ func main() {
RoomserverAPI: rsAPI,
StateAPI: stateAPI,
UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, userAPI, base.KafkaProducer),
KeyAPI: keyAPI,
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider,
}

View file

@ -20,6 +20,7 @@ import (
"crypto/ed25519"
"encoding/json"
"net/http"
"os"
"reflect"
"testing"
"time"
@ -92,13 +93,14 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
return nil
}
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) {
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
cfg := &config.Dendrite{}
cfg.Defaults()
stateDBName := "test_state.db"
naffkaDBName := "test_naffka.db"
cfg.Global.ServerName = "kaer.morhen"
cfg.Global.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic)
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file::memory:")
db, err := sqlutil.Open(&cfg.CurrentStateServer.Database)
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file:" + stateDBName)
db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil)
if err != nil {
t.Fatalf("Failed to open naffka database: %s", err)
}
@ -110,11 +112,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
if err != nil {
t.Fatalf("Failed to create naffka consumer: %s", err)
}
return NewInternalAPI(&cfg.CurrentStateServer, naff), naff
return NewInternalAPI(cfg, naff), naff, func() {
os.Remove(naffkaDBName)
os.Remove(stateDBName)
}
}
func TestQueryCurrentState(t *testing.T) {
currStateAPI, producer := MustMakeInternalAPI(t)
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
defer cancel()
plTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.power_levels",
StateKey: "",
@ -217,7 +223,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
func TestQuerySharedUsers(t *testing.T) {
currStateAPI, producer := MustMakeInternalAPI(t)
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
defer cancel()
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
@ -230,6 +237,9 @@ func TestQuerySharedUsers(t *testing.T) {
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
// we don't know when the server has processed the events
time.Sleep(10 * time.Millisecond)
testCases := []struct {
req api.QuerySharedUsersRequest
wantRes api.QuerySharedUsersResponse

View file

@ -19,14 +19,19 @@ import (
"encoding/json"
"strings"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type KeyInternalAPI interface {
// SetUserAPI assigns a user API to query when extracting device names.
SetUserAPI(i userapi.UserInternalAPI)
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
}
// KeyError is returned if there was a problem performing/querying the server
@ -38,6 +43,13 @@ func (k *KeyError) Error() string {
return k.Err
}
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
DeviceKeys
// A monotonically increasing number which represents device changes for this user.
StreamID int
}
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {
@ -45,10 +57,20 @@ type DeviceKeys struct {
UserID string
// The device ID of this device
DeviceID string
// The device display name
DisplayName string
// The raw device key JSON
KeyJSON []byte
}
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
return DeviceMessage{
DeviceKeys: *k,
StreamID: streamID,
}
}
// OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct {
@ -153,3 +175,16 @@ type QueryKeyChangesResponse struct {
// Set if there was a problem handling the request.
Error *KeyError
}
type QueryOneTimeKeysRequest struct {
// The local user to query OTK counts for
UserID string
// The device to query OTK counts for
DeviceID string
}
type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
Count OneTimeKeysCount
Error *KeyError
}

View file

@ -40,6 +40,10 @@ type KeyInternalAPI struct {
Producer *producers.KeyChange
}
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
a.UserAPI = i
}
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
if req.Partition < 0 {
req.Partition = a.Producer.DefaultPartition()
@ -57,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadDeviceKeys(ctx, req, res)
a.uploadLocalDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res)
}
@ -164,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
}
return
}
res.Count = *count
}
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
@ -202,6 +217,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, dk := range deviceKeys {
if len(dk.KeyJSON) == 0 {
continue // don't include blank keys
}
// inject display name if known
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
@ -268,14 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceKeys
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
if serverName != a.ThisServer {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 {
keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
keysToStore = append(keysToStore, key)
keysToStore = append(keysToStore, key.WithStreamID(0))
continue
}
@ -286,12 +315,15 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
),
})
}
// get existing device keys so we can check for changes
existingKeys := make([]api.DeviceKeys, len(keysToStore))
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceKeys{
existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
@ -301,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
return
}
// store the device keys and emit changes
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
}
return
}
err := a.emitDeviceKeyChanges(existingKeys, keysToStore)
err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
@ -352,13 +385,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error {
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
// find keys in new that are not in existing
var keysAdded []api.DeviceKeys
var keysAdded []api.DeviceMessage
for _, newKey := range new {
exists := false
for _, existingKey := range existing {
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) {
// Do not treat the absence of keys as equal, or else we will not emit key changes
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
exists = true
break
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/opentracing/opentracing-go"
)
@ -30,6 +31,7 @@ const (
PerformClaimKeysPath = "/keyserver/performClaimKeys"
QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
)
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
@ -52,6 +54,10 @@ type httpKeyInternalAPI struct {
httpClient *http.Client
}
func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
// no-op: doesn't need it
}
func (h *httpKeyInternalAPI) PerformClaimKeys(
ctx context.Context,
request *api.PerformClaimKeysRequest,
@ -103,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys(
}
}
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
ctx context.Context,
request *api.QueryOneTimeKeysRequest,
response *api.QueryOneTimeKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
defer span.Finish()
apiURL := h.apiURL + QueryOneTimeKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,

View file

@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryOneTimeKeysPath,
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryOneTimeKeysRequest{}
response := api.QueryOneTimeKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryOneTimeKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/keyserver/inthttp"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -37,7 +36,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, producer sarama.SyncProducer,
cfg *config.KeyServer, fedClient *gomatrixserverlib.FederationClient, producer sarama.SyncProducer,
) api.KeyInternalAPI {
db, err := storage.NewDatabase(&cfg.Database)
if err != nil {
@ -52,7 +51,6 @@ func NewInternalAPI(
DB: db,
ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient,
UserAPI: userAPI,
Producer: keyChangeProducer,
}
}

View file

@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
}
// ProduceKeyChanges creates new change events for each key
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
for _, key := range keys {
var m sarama.ProducerMessage
@ -67,6 +67,7 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
"device_id": key.DeviceID,
"partition": partition,
"offset": offset,
"len_key_bytes": len(key.KeyJSON),
}).Infof("Produced to key change topic '%s'", p.Topic)
}
return nil

View file

@ -29,16 +29,21 @@ type Database interface {
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced.
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.

View file

@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
-- This means we do not store an unbounded append-only log of device keys, which is not actually
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4)" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4"
" DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
}
return nil
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
now := time.Now().Unix()
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
}
return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
)
if err != nil {
return err
}
}
return nil
})
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
var result []api.DeviceKeys
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceKeys
var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)

View file

@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{

View file

@ -39,15 +39,40 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
}
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = int(streamID)
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
}

View file

@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4)" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4"
" DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceKeys
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceKeys
var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return result, rows.Err()
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
}
return nil
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
now := time.Now().Unix()
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
}
return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
)
if err != nil {
return err
}
}
return nil
})
}

View file

@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{

View file

@ -7,6 +7,7 @@ import (
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api"
)
var ctx = context.Background()
@ -82,3 +83,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
}

View file

@ -24,6 +24,7 @@ import (
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.
@ -31,9 +32,10 @@ type OneTimeKeys interface {
}
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
}
type KeyChanges interface {

View file

@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
defer func() {
s.updateOffset(msg)
}()
var output api.DeviceKeys
var output api.DeviceMessage
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")

View file

@ -35,6 +35,7 @@ type OutputRoomEventConsumer struct {
rsConsumer *internal.ContinualConsumer
db storage.Database
notifier *sync.Notifier
keyChanges *OutputKeyChangeEventConsumer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -44,6 +45,7 @@ func NewOutputRoomEventConsumer(
n *sync.Notifier,
store storage.Database,
rsAPI api.RoomserverInternalAPI,
keyChanges *OutputKeyChangeEventConsumer,
) *OutputRoomEventConsumer {
consumer := internal.ContinualConsumer{
@ -56,6 +58,7 @@ func NewOutputRoomEventConsumer(
db: store,
notifier: n,
rsAPI: rsAPI,
keyChanges: keyChanges,
}
consumer.ProcessMessage = s.onMessage
@ -160,9 +163,29 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.notifyKeyChanges(&ev)
return nil
}
func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.HeaderedEvent) {
if ev.Type() != gomatrixserverlib.MRoomMember || ev.StateKey() == nil {
return
}
membership, err := ev.Membership()
if err != nil {
return
}
switch membership {
case gomatrixserverlib.Join:
s.keyChanges.OnJoinEvent(ev)
case gomatrixserverlib.Ban:
fallthrough
case gomatrixserverlib.Leave:
s.keyChanges.OnLeaveEvent(ev)
}
}
func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent,
) error {

View file

@ -16,6 +16,7 @@ package internal
import (
"context"
"strings"
"github.com/Shopify/sarama"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
@ -28,6 +29,20 @@ import (
const DeviceListLogName = "dl"
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error {
var queryRes api.QueryOneTimeKeysResponse
keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
if queryRes.Error != nil {
return queryRes.Error
}
res.DeviceListsOTKCount = queryRes.Count.KeyCount
return nil
}
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
@ -35,6 +50,7 @@ func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
userID string, res *types.Response, from, to types.StreamingToken,
) (hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID)
newlyLeftRooms := leftRooms(res)
@ -88,6 +104,16 @@ func DeviceListCatchup(
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
}
}
// if the response has any join/leave events, add them now.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
for _, userID := range membershipEvents(res) {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
}
}
return hasNew, nil
@ -219,3 +245,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
}
return false
}
// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to
// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
func membershipEvents(res *types.Response) (userIDs []string) {
for _, room := range res.Rooms.Join {
for _, ev := range room.Timeline.Events {
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
if strings.Contains(string(ev.Content), `"join"`) {
userIDs = append(userIDs, *ev.StateKey)
} else if strings.Contains(string(ev.Content), `"leave"`) {
userIDs = append(userIDs, *ev.StateKey)
} else if strings.Contains(string(ev.Content), `"ban"`) {
userIDs = append(userIDs, *ev.StateKey)
}
}
}
}
return
}

View file

@ -10,6 +10,7 @@ import (
"github.com/matrix-org/dendrite/currentstateserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -29,12 +30,17 @@ type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
}
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
}
type mockCurrentStateAPI struct {

View file

@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
}
// work out room joins/leaves
res, err := rp.db.IncrementalSync(
req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false,
req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false,
)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync")
@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
}
}
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
res = types.NewResponse()
// nolint:gocyclo
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) {
res := types.NewResponse()
since := types.NewStreamToken(0, 0, nil)
if req.since != nil {
@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
}
if err != nil {
return
return res, err
}
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
if err != nil {
return
return res, err
}
res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos)
if err != nil {
return
return res, err
}
err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res)
if err != nil {
return res, err
}
// Before we return the sync response, make sure that we take action on
@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
// Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
if err != nil {
return
return res, err
}
}
if len(events) > 0 {
@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
}
}
return
return res, err
}
func (rp *RequestPool) appendDeviceLists(

View file

@ -64,8 +64,16 @@ func AddPublicRoutes(
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
consumer, notifier, keyAPI, currentStateAPI, syncDB,
)
if err = keyChangeConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start key change consumer")
}
roomConsumer := consumers.NewOutputRoomEventConsumer(
cfg, consumer, notifier, syncDB, rsAPI,
cfg, consumer, notifier, syncDB, rsAPI, keyChangeConsumer,
)
if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer")
@ -92,13 +100,5 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
consumer, notifier, keyAPI, currentStateAPI, syncDB,
)
if err = keyChangeConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start key change consumer")
}
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
}

View file

@ -393,6 +393,7 @@ type Response struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
} `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"`
}
// NewResponse creates an empty response with initialised maps.
@ -411,6 +412,7 @@ func NewResponse() *Response {
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
res.DeviceListsOTKCount = make(map[string]int)
return &res
}

View file

@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates
Sync is woken up for leaves
Newly left rooms appear in the leave section of incremental sync
Rooms can be created with an initial invite list (SYN-205)
We should see our own leave event, even if history_visibility is restricted (SYN-662)
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
Newly left rooms appear in the leave section of gapped sync
@ -129,6 +130,11 @@ Can claim one time key using POST
Can claim remote one time key using POST
Local device key changes appear in v2 /sync
Local device key changes appear in /keys/changes
New users appear in /keys/changes
Local delete device changes appear in v2 /sync
Local new device changes appear in v2 /sync
Local update device changes appear in v2 /sync
Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves
Can add account data
Can add account data to room

View file

@ -27,6 +27,8 @@ type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
@ -47,6 +49,25 @@ type InputAccountDataRequest struct {
type InputAccountDataResponse struct {
}
type PerformDeviceUpdateRequest struct {
RequestingUserID string
DeviceID string
DisplayName *string
}
type PerformDeviceUpdateResponse struct {
DeviceExists bool
Forbidden bool
}
type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete
DeviceIDs []string
}
type PerformDeviceDeletionResponse struct {
}
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
type QueryDeviceInfosRequest struct {
DeviceIDs []string

View file

@ -25,10 +25,12 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type UserInternalAPI struct {
@ -37,6 +39,7 @@ type UserInternalAPI struct {
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.KeyInternalAPI
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -101,6 +104,76 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
}
res.DeviceCreated = true
res.Device = dev
// create empty device keys and upload them to trigger device list changes
return a.deviceListUpdate(dev.UserID, []string{dev.ID})
}
func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error {
util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion")
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName)
}
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
if err != nil {
return err
}
// create empty device keys and upload them to delete what was once there and trigger device list changes
return a.deviceListUpdate(req.UserID, req.DeviceIDs)
}
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
for i, did := range deviceIDs {
deviceKeys[i] = keyapi.DeviceKeys{
UserID: userID,
DeviceID: did,
KeyJSON: nil,
}
}
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
DeviceKeys: deviceKeys,
}, &uploadRes)
if uploadRes.Error != nil {
return fmt.Errorf("Failed to delete device keys: %v", uploadRes.Error)
}
if len(uploadRes.KeyErrors) > 0 {
return fmt.Errorf("Failed to delete device keys, key errors: %+v", uploadRes.KeyErrors)
}
return nil
}
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
} else if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
return err
}
res.DeviceExists = true
if dev.UserID != req.RequestingUserID {
res.Forbidden = true
return nil
}
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
}
return nil
}

View file

@ -30,6 +30,8 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken"
@ -91,6 +93,26 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceDeletion(
ctx context.Context,
request *api.PerformDeviceDeletionRequest,
response *api.PerformDeviceDeletionResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformDeviceUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryProfile(
ctx context.Context,
request *api.QueryProfileRequest,

View file

@ -52,6 +52,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceUpdatePath,
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceUpdateRequest{}
response := api.PerformDeviceUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceDeletionPath,
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceDeletionRequest{}
response := api.PerformDeviceDeletionResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryProfilePath,
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
request := api.QueryProfileRequest{}

View file

@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1)
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices(
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
})

View file

@ -17,6 +17,7 @@ package userapi
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/config"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
@ -34,12 +35,13 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database,
serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService) api.UserInternalAPI {
serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI {
return &internal.UserInternalAPI{
AccountDB: accountDB,
DeviceDB: deviceDB,
ServerName: serverName,
AppServices: appServices,
KeyAPI: keyAPI,
}
}

View file

@ -37,7 +37,7 @@ func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database,
t.Fatalf("failed to create device DB: %s", err)
}
return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil), accountDB, deviceDB
return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB
}
func TestQueryProfile(t *testing.T) {