diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 79c331424..c9430543f 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -186,7 +186,7 @@ func main() { ServerKeyAPI: serverKeyAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI), ExtPublicRoomsProvider: provider, } monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 3cf0168ec..8666e8f52 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -141,7 +141,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI), //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index 1bde56368..70d8394f5 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -30,10 +30,11 @@ func main() { keyRing := serverKeyAPI.KeyRing() fsAPI := base.FederationSenderHTTPClient() rsAPI := base.RoomserverHTTPClient() + keyAPI := base.KeyServerHTTPClient() federationapi.AddPublicRoutes( base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing, - rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), + rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI, ) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go index 7dabc258a..1aafa1447 100644 --- a/cmd/dendrite-key-server/main.go +++ b/cmd/dendrite-key-server/main.go @@ -24,7 +24,7 @@ func main() { base := setup.NewBaseDendrite(cfg, "KeyServer", true) defer base.Close() // nolint: errcheck - intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient()) + intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 93d62343d..80a45c991 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -119,7 +119,7 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - keyAPI := keyserver.NewInternalAPI(base.Cfg, federation) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI) monolith := setup.Monolith{ Config: base.Cfg, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 3d58d957f..0bb2dbe9f 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -233,7 +233,7 @@ func main() { RoomserverAPI: rsAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI), //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 7d1994b25..079f333a4 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -20,6 +20,7 @@ import ( eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" + keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -38,11 +39,12 @@ func AddPublicRoutes( federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + keyAPI keyserverAPI.KeyInternalAPI, ) { routing.Setup( router, cfg, rsAPI, eduAPI, federationSenderAPI, keyRing, - federation, userAPI, stateAPI, + federation, userAPI, stateAPI, keyAPI, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 6bbe9d80e..8bc4277eb 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { fsAPI := base.FederationSenderHTTPClient() // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil) + federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil, nil) httputil.SetupHTTPAPI( base.BaseMux, base.PublicAPIMux, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index a1dd0fd09..90eec9e0e 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,12 +19,106 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "golang.org/x/crypto/ed25519" ) +type queryKeysRequest struct { + DeviceKeys map[string][]string `json:"device_keys"` +} + +// QueryDeviceKeys returns device keys for users on this server. +// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query +func QueryDeviceKeys( + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, +) util.JSONResponse { + var qkr queryKeysRequest + err := json.Unmarshal(request.Content(), &qkr) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + // make sure we only query users on our domain + for userID := range qkr.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + delete(qkr.DeviceKeys, userID) + continue // ignore invalid users + } + if serverName != thisServer { + delete(qkr.DeviceKeys, userID) + continue + } + } + + var queryRes api.QueryKeysResponse + keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + UserToDevices: qkr.DeviceKeys, + }, &queryRes) + if queryRes.Error != nil { + util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: 200, + JSON: struct { + DeviceKeys interface{} `json:"device_keys"` + }{queryRes.DeviceKeys}, + } +} + +type claimOTKsRequest struct { + OneTimeKeys map[string]map[string]string `json:"one_time_keys"` +} + +// ClaimOneTimeKeys claims OTKs for users on this server. +// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim +func ClaimOneTimeKeys( + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, +) util.JSONResponse { + var cor claimOTKsRequest + err := json.Unmarshal(request.Content(), &cor) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + // make sure we only claim users on our domain + for userID := range cor.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + delete(cor.OneTimeKeys, userID) + continue // ignore invalid users + } + if serverName != thisServer { + delete(cor.OneTimeKeys, userID) + continue + } + } + + var claimRes api.PerformClaimKeysResponse + keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + OneTimeKeys: cor.OneTimeKeys, + }, &claimRes) + if claimRes.Error != nil { + util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: 200, + JSON: struct { + OneTimeKeys interface{} `json:"one_time_keys"` + }{claimRes.OneTimeKeys}, + } +} + // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys func LocalKeys(cfg *config.Dendrite) util.JSONResponse { diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index cd97f2978..50b7bdd28 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -24,6 +24,7 @@ import ( federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/httputil" + keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -54,6 +55,7 @@ func Setup( federation *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + keyAPI keyserverAPI.KeyInternalAPI, ) { v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() @@ -299,4 +301,18 @@ func Setup( return GetPostPublicRooms(req, rsAPI, stateAPI) }), ).Methods(http.MethodGet) + + v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI( + "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + }, + )).Methods(http.MethodPost) + + v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI( + "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + }, + )).Methods(http.MethodPost) } diff --git a/go.mod b/go.mod index dfdc66444..f087b087f 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b + github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index a7c8a05b4..de7527d92 100644 --- a/go.sum +++ b/go.sum @@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 39013a2cd..1f6d9a761 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) { federationapi.AddPublicRoutes( publicMux, m.Config, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, - m.EDUInternalAPI, m.StateAPI, + m.EDUInternalAPI, m.StateAPI, m.KeyAPI, ) mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) syncapi.AddPublicRoutes( diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index e406dab4f..174a72dcd 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -24,7 +24,9 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -33,6 +35,7 @@ type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName FedClient *gomatrixserverlib.FederationClient + UserAPI userapi.UserInternalAPI } func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -66,11 +69,25 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), } } - mergeInto(res.OneTimeKeys, keys) + util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") + for _, key := range keys { + _, ok := res.OneTimeKeys[key.UserID] + if !ok { + res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] + if !ok { + res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON + } + } delete(domainToDeviceKeys, string(a.ThisServer)) } - // claim remote keys - a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + if len(domainToDeviceKeys) > 0 { + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + } } func (a *KeyInternalAPI) claimRemoteKeys( @@ -82,6 +99,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( wg.Add(len(domainToDeviceKeys)) // mutex for failures var failMu sync.Mutex + util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") // fan out for d, k := range domainToDeviceKeys { @@ -91,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) if err != nil { + util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") failMu.Lock() res.Failures[domain] = map[string]interface{}{ "message": err.Error(), @@ -108,6 +127,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( close(resultCh) }() + keysClaimed := 0 for result := range resultCh { for userID, nest := range result.OneTimeKeys { res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) @@ -119,10 +139,12 @@ func (a *KeyInternalAPI) claimRemoteKeys( continue } res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON + keysClaimed++ } } } } + util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { @@ -145,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } return } + + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) + } + var queryRes userapi.QueryDeviceInfosResponse + err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + } + if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { - // inject an empty 'unsigned' key which should be used for display names - // (but not via this API? unsure when they should be added) - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{}) + // inject display name if known + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { @@ -298,19 +335,3 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { // TODO } - -func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) { - for _, key := range src { - _, ok := dst[key.UserID] - if !ok { - dst[key.UserID] = make(map[string]map[string]json.RawMessage) - } - _, ok = dst[key.UserID][key.DeviceID] - if !ok { - dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) - } - for keyID, keyJSON := range key.KeyJSON { - dst[key.UserID][key.DeviceID][keyID] = keyJSON - } - } -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 714b59f0b..2e1ddb6cc 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI { +func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI { db, err := storage.NewDatabase( string(cfg.Database.E2EKey), cfg.DbProperties(), @@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio DB: db, ThisServer: cfg.Matrix.ServerName, FedClient: fedClient, + UserAPI: userAPI, } } diff --git a/sytest-whitelist b/sytest-whitelist index f21432fbd..5bf6d68bf 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -122,9 +122,11 @@ User can invite local user to room with version 1 Can upload device keys Should reject keys claiming to belong to a different user Can query device keys using POST +Can query remote device keys using POST Can query specific device keys using POST query for user with no keys returns empty key dict Can claim one time key using POST +Can claim remote one time key using POST Can add account data Can add account data to room Can get account data without syncing diff --git a/userapi/api/api.go b/userapi/api/api.go index cf0f05633..bd0773f87 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -30,6 +30,7 @@ type UserInternalAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -44,6 +45,19 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +// QueryDeviceInfosRequest is the request to QueryDeviceInfos +type QueryDeviceInfosRequest struct { + DeviceIDs []string +} + +// QueryDeviceInfosResponse is the response to QueryDeviceInfos +type QueryDeviceInfosResponse struct { + DeviceInfo map[string]struct { + DisplayName string + UserID string + } +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1d10d1d8b..2de8f9607 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil return nil } +func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { + devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + if err != nil { + return err + } + res.DeviceInfo = make(map[string]struct { + DisplayName string + UserID string + }) + for _, d := range devices { + res.DeviceInfo[d.ID] = struct { + DisplayName string + UserID string + }{ + DisplayName: d.DisplayName, + UserID: d.UserID, + } + } + return nil +} + func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4ab0d690e..b2b42823f 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -35,6 +35,7 @@ const ( QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) QueryDeviceInfos( + ctx context.Context, + request *api.QueryDeviceInfosRequest, + response *api.QueryDeviceInfosResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") + defer span.Finish() + + apiURL := h.apiURL + QueryDeviceInfosPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryAccessToken( ctx context.Context, request *api.QueryAccessTokenRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 8f3be7738..d8e151ad4 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/util" ) +// nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { @@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceInfosPath, + httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse { + request := api.QueryDeviceInfosRequest{} + response := api.QueryDeviceInfosResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 4bdb57850..3c9ec934a 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -24,6 +24,7 @@ type Database interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 1d036d1b3..03bf7c722 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt @@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} + func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 801657bd5..6ac802bb1 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ec52c64bc..efe6f927c 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + type devicesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -79,6 +83,7 @@ type devicesStatements struct { selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt @@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } @@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } + +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) + iDeviceIDs := make([]interface{}, len(deviceIDs)) + for i := range deviceIDs { + iDeviceIDs[i] = deviceIDs[i] + } + + rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f248abda4..b9f08ca11 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device,