Handle inbound federation E2E key queries/claims (#1215)

* Handle inbound /keys/claim and /keys/query requests

* Add display names to device key responses

* Linting
This commit is contained in:
Kegsay 2020-07-22 17:04:57 +01:00 committed by GitHub
parent 1e71fd645e
commit 541a23f712
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 321 additions and 35 deletions

View file

@ -186,7 +186,7 @@ func main() {
ServerKeyAPI: serverKeyAPI, ServerKeyAPI: serverKeyAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI),
ExtPublicRoomsProvider: provider, ExtPublicRoomsProvider: provider,
} }
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)

View file

@ -141,7 +141,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation, ygg, fsAPI, federation,

View file

@ -30,10 +30,11 @@ func main() {
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
fsAPI := base.FederationSenderHTTPClient() fsAPI := base.FederationSenderHTTPClient()
rsAPI := base.RoomserverHTTPClient() rsAPI := base.RoomserverHTTPClient()
keyAPI := base.KeyServerHTTPClient()
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing, base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing,
rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI,
) )
base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI))

View file

@ -24,7 +24,7 @@ func main() {
base := setup.NewBaseDendrite(cfg, "KeyServer", true) base := setup.NewBaseDendrite(cfg, "KeyServer", true)
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient()) intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)

View file

@ -119,7 +119,7 @@ func main() {
rsImpl.SetFederationSenderAPI(fsAPI) rsImpl.SetFederationSenderAPI(fsAPI)
stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer)
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation) keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,

View file

@ -233,7 +233,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider, ExtPublicRoomsProvider: p2pPublicRoomProvider,
} }

View file

@ -20,6 +20,7 @@ import (
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -38,11 +39,12 @@ func AddPublicRoutes(
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
eduAPI eduserverAPI.EDUServerInputAPI, eduAPI eduserverAPI.EDUServerInputAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
keyAPI keyserverAPI.KeyInternalAPI,
) { ) {
routing.Setup( routing.Setup(
router, cfg, rsAPI, router, cfg, rsAPI,
eduAPI, federationSenderAPI, keyRing, eduAPI, federationSenderAPI, keyRing,
federation, userAPI, stateAPI, federation, userAPI, stateAPI, keyAPI,
) )
} }

View file

@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
fsAPI := base.FederationSenderHTTPClient() fsAPI := base.FederationSenderHTTPClient()
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil) federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil, nil)
httputil.SetupHTTPAPI( httputil.SetupHTTPAPI(
base.BaseMux, base.BaseMux,
base.PublicAPIMux, base.PublicAPIMux,

View file

@ -19,12 +19,106 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
type queryKeysRequest struct {
DeviceKeys map[string][]string `json:"device_keys"`
}
// QueryDeviceKeys returns device keys for users on this server.
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query
func QueryDeviceKeys(
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
var qkr queryKeysRequest
err := json.Unmarshal(request.Content(), &qkr)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// make sure we only query users on our domain
for userID := range qkr.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
delete(qkr.DeviceKeys, userID)
continue // ignore invalid users
}
if serverName != thisServer {
delete(qkr.DeviceKeys, userID)
continue
}
}
var queryRes api.QueryKeysResponse
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys,
}, &queryRes)
if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: 200,
JSON: struct {
DeviceKeys interface{} `json:"device_keys"`
}{queryRes.DeviceKeys},
}
}
type claimOTKsRequest struct {
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
}
// ClaimOneTimeKeys claims OTKs for users on this server.
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim
func ClaimOneTimeKeys(
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
var cor claimOTKsRequest
err := json.Unmarshal(request.Content(), &cor)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// make sure we only claim users on our domain
for userID := range cor.OneTimeKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
delete(cor.OneTimeKeys, userID)
continue // ignore invalid users
}
if serverName != thisServer {
delete(cor.OneTimeKeys, userID)
continue
}
}
var claimRes api.PerformClaimKeysResponse
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys,
}, &claimRes)
if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: 200,
JSON: struct {
OneTimeKeys interface{} `json:"one_time_keys"`
}{claimRes.OneTimeKeys},
}
}
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.Dendrite) util.JSONResponse { func LocalKeys(cfg *config.Dendrite) util.JSONResponse {

View file

@ -24,6 +24,7 @@ import (
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -54,6 +55,7 @@ func Setup(
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
keyAPI keyserverAPI.KeyInternalAPI,
) { ) {
v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter()
@ -299,4 +301,18 @@ func Setup(
return GetPostPublicRooms(req, rsAPI, stateAPI) return GetPostPublicRooms(req, rsAPI, stateAPI)
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI(
"federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI(
"federation_keys_query", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
} }

2
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/mattn/go-sqlite3 v2.0.2+incompatible github.com/mattn/go-sqlite3 v2.0.2+incompatible

2
go.sum
View file

@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
publicMux, m.Config, m.UserAPI, m.FedClient, publicMux, m.Config, m.UserAPI, m.FedClient,
m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
m.EDUInternalAPI, m.StateAPI, m.EDUInternalAPI, m.StateAPI, m.KeyAPI,
) )
mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client)
syncapi.AddPublicRoutes( syncapi.AddPublicRoutes(

View file

@ -24,7 +24,9 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@ -33,6 +35,7 @@ type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName ThisServer gomatrixserverlib.ServerName
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient
UserAPI userapi.UserInternalAPI
} }
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
@ -66,12 +69,26 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
} }
} }
mergeInto(res.OneTimeKeys, keys) util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
for _, key := range keys {
_, ok := res.OneTimeKeys[key.UserID]
if !ok {
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
}
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
if !ok {
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
}
for keyID, keyJSON := range key.KeyJSON {
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
delete(domainToDeviceKeys, string(a.ThisServer)) delete(domainToDeviceKeys, string(a.ThisServer))
} }
// claim remote keys if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
} }
}
func (a *KeyInternalAPI) claimRemoteKeys( func (a *KeyInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
@ -82,6 +99,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
wg.Add(len(domainToDeviceKeys)) wg.Add(len(domainToDeviceKeys))
// mutex for failures // mutex for failures
var failMu sync.Mutex var failMu sync.Mutex
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
// fan out // fan out
for d, k := range domainToDeviceKeys { for d, k := range domainToDeviceKeys {
@ -91,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
failMu.Lock() failMu.Lock()
res.Failures[domain] = map[string]interface{}{ res.Failures[domain] = map[string]interface{}{
"message": err.Error(), "message": err.Error(),
@ -108,6 +127,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
close(resultCh) close(resultCh)
}() }()
keysClaimed := 0
for result := range resultCh { for result := range resultCh {
for userID, nest := range result.OneTimeKeys { for userID, nest := range result.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
@ -119,10 +139,12 @@ func (a *KeyInternalAPI) claimRemoteKeys(
continue continue
} }
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
keysClaimed++
} }
} }
} }
} }
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
} }
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
@ -145,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
return return
} }
// pull out display names after we have the keys so we handle wildcards correctly
var dids []string
for _, dk := range deviceKeys {
dids = append(dids, dk.DeviceID)
}
var queryRes userapi.QueryDeviceInfosResponse
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
DeviceIDs: dids,
}, &queryRes)
if err != nil {
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
}
if res.DeviceKeys[userID] == nil { if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
for _, dk := range deviceKeys { for _, dk := range deviceKeys {
// inject an empty 'unsigned' key which should be used for display names // inject display name if known
// (but not via this API? unsure when they should be added) dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{}) DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
} }
} else { } else {
@ -298,19 +335,3 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
// TODO // TODO
} }
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
for _, key := range src {
_, ok := dst[key.UserID]
if !ok {
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
}
_, ok = dst[key.UserID][key.DeviceID]
if !ok {
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
}
for keyID, keyJSON := range key.KeyJSON {
dst[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/inthttp"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI { func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI {
db, err := storage.NewDatabase( db, err := storage.NewDatabase(
string(cfg.Database.E2EKey), string(cfg.Database.E2EKey),
cfg.DbProperties(), cfg.DbProperties(),
@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient, FedClient: fedClient,
UserAPI: userAPI,
} }
} }

View file

@ -122,9 +122,11 @@ User can invite local user to room with version 1
Can upload device keys Can upload device keys
Should reject keys claiming to belong to a different user Should reject keys claiming to belong to a different user
Can query device keys using POST Can query device keys using POST
Can query remote device keys using POST
Can query specific device keys using POST Can query specific device keys using POST
query for user with no keys returns empty key dict query for user with no keys returns empty key dict
Can claim one time key using POST Can claim one time key using POST
Can claim remote one time key using POST
Can add account data Can add account data
Can add account data to room Can add account data to room
Can get account data without syncing Can get account data without syncing

View file

@ -30,6 +30,7 @@ type UserInternalAPI interface {
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
@ -44,6 +45,19 @@ type InputAccountDataRequest struct {
type InputAccountDataResponse struct { type InputAccountDataResponse struct {
} }
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
type QueryDeviceInfosRequest struct {
DeviceIDs []string
}
// QueryDeviceInfosResponse is the response to QueryDeviceInfos
type QueryDeviceInfosResponse struct {
DeviceInfo map[string]struct {
DisplayName string
UserID string
}
}
// QueryAccessTokenRequest is the request for QueryAccessToken // QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct { type QueryAccessTokenRequest struct {
AccessToken string AccessToken string

View file

@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
return nil return nil
} }
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
res.DeviceInfo = make(map[string]struct {
DisplayName string
UserID string
})
for _, d := range devices {
res.DeviceInfo[d.ID] = struct {
DisplayName string
UserID string
}{
DisplayName: d.DisplayName,
UserID: d.UserID,
}
}
return nil
}
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {

View file

@ -35,6 +35,7 @@ const (
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices" QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData" QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryDeviceInfos(
ctx context.Context,
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceInfosPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryAccessToken( func (h *httpUserInternalAPI) QueryAccessToken(
ctx context.Context, ctx context.Context,
request *api.QueryAccessTokenRequest, request *api.QueryAccessTokenRequest,

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// nolint: gocyclo
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(PerformAccountCreationPath, internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceInfosRequest{}
response := api.QueryDeviceInfosResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -24,6 +24,7 @@ type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return return
} }
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
} }
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {

View file

@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -20,6 +20,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -79,6 +83,7 @@ type devicesStatements struct {
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return return
} }
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
} }
return &dev, err return &dev, err
} }
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil return devices, nil
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
iDeviceIDs[i] = deviceIDs[i]
}
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}

View file

@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,