Handle inbound federation E2E key queries/claims (#1215)
* Handle inbound /keys/claim and /keys/query requests * Add display names to device key responses * Linting
This commit is contained in:
parent
1e71fd645e
commit
541a23f712
|
@ -186,7 +186,7 @@ func main() {
|
||||||
ServerKeyAPI: serverKeyAPI,
|
ServerKeyAPI: serverKeyAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation),
|
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI),
|
||||||
ExtPublicRoomsProvider: provider,
|
ExtPublicRoomsProvider: provider,
|
||||||
}
|
}
|
||||||
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)
|
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)
|
||||||
|
|
|
@ -141,7 +141,7 @@ func main() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation),
|
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
|
||||||
//ServerKeyAPI: serverKeyAPI,
|
//ServerKeyAPI: serverKeyAPI,
|
||||||
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
|
||||||
ygg, fsAPI, federation,
|
ygg, fsAPI, federation,
|
||||||
|
|
|
@ -30,10 +30,11 @@ func main() {
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
fsAPI := base.FederationSenderHTTPClient()
|
fsAPI := base.FederationSenderHTTPClient()
|
||||||
rsAPI := base.RoomserverHTTPClient()
|
rsAPI := base.RoomserverHTTPClient()
|
||||||
|
keyAPI := base.KeyServerHTTPClient()
|
||||||
|
|
||||||
federationapi.AddPublicRoutes(
|
federationapi.AddPublicRoutes(
|
||||||
base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing,
|
base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing,
|
||||||
rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(),
|
rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI,
|
||||||
)
|
)
|
||||||
|
|
||||||
base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI))
|
base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI))
|
||||||
|
|
|
@ -24,7 +24,7 @@ func main() {
|
||||||
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
|
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
|
||||||
defer base.Close() // nolint: errcheck
|
defer base.Close() // nolint: errcheck
|
||||||
|
|
||||||
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient())
|
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient())
|
||||||
|
|
||||||
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)
|
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,7 @@ func main() {
|
||||||
rsImpl.SetFederationSenderAPI(fsAPI)
|
rsImpl.SetFederationSenderAPI(fsAPI)
|
||||||
|
|
||||||
stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer)
|
stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer)
|
||||||
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation)
|
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI)
|
||||||
|
|
||||||
monolith := setup.Monolith{
|
monolith := setup.Monolith{
|
||||||
Config: base.Cfg,
|
Config: base.Cfg,
|
||||||
|
|
|
@ -233,7 +233,7 @@ func main() {
|
||||||
RoomserverAPI: rsAPI,
|
RoomserverAPI: rsAPI,
|
||||||
StateAPI: stateAPI,
|
StateAPI: stateAPI,
|
||||||
UserAPI: userAPI,
|
UserAPI: userAPI,
|
||||||
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation),
|
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
|
||||||
//ServerKeyAPI: serverKeyAPI,
|
//ServerKeyAPI: serverKeyAPI,
|
||||||
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
ExtPublicRoomsProvider: p2pPublicRoomProvider,
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
@ -38,11 +39,12 @@ func AddPublicRoutes(
|
||||||
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
|
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||||
eduAPI eduserverAPI.EDUServerInputAPI,
|
eduAPI eduserverAPI.EDUServerInputAPI,
|
||||||
stateAPI currentstateAPI.CurrentStateInternalAPI,
|
stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
|
keyAPI keyserverAPI.KeyInternalAPI,
|
||||||
) {
|
) {
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
router, cfg, rsAPI,
|
router, cfg, rsAPI,
|
||||||
eduAPI, federationSenderAPI, keyRing,
|
eduAPI, federationSenderAPI, keyRing,
|
||||||
federation, userAPI, stateAPI,
|
federation, userAPI, stateAPI, keyAPI,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||||
fsAPI := base.FederationSenderHTTPClient()
|
fsAPI := base.FederationSenderHTTPClient()
|
||||||
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
|
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
|
||||||
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
|
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
|
||||||
federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil)
|
federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil, nil)
|
||||||
httputil.SetupHTTPAPI(
|
httputil.SetupHTTPAPI(
|
||||||
base.BaseMux,
|
base.BaseMux,
|
||||||
base.PublicAPIMux,
|
base.PublicAPIMux,
|
||||||
|
|
|
@ -19,12 +19,106 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type queryKeysRequest struct {
|
||||||
|
DeviceKeys map[string][]string `json:"device_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDeviceKeys returns device keys for users on this server.
|
||||||
|
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query
|
||||||
|
func QueryDeviceKeys(
|
||||||
|
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
|
||||||
|
) util.JSONResponse {
|
||||||
|
var qkr queryKeysRequest
|
||||||
|
err := json.Unmarshal(request.Content(), &qkr)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// make sure we only query users on our domain
|
||||||
|
for userID := range qkr.DeviceKeys {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
delete(qkr.DeviceKeys, userID)
|
||||||
|
continue // ignore invalid users
|
||||||
|
}
|
||||||
|
if serverName != thisServer {
|
||||||
|
delete(qkr.DeviceKeys, userID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var queryRes api.QueryKeysResponse
|
||||||
|
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
|
||||||
|
UserToDevices: qkr.DeviceKeys,
|
||||||
|
}, &queryRes)
|
||||||
|
if queryRes.Error != nil {
|
||||||
|
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
DeviceKeys interface{} `json:"device_keys"`
|
||||||
|
}{queryRes.DeviceKeys},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type claimOTKsRequest struct {
|
||||||
|
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimOneTimeKeys claims OTKs for users on this server.
|
||||||
|
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim
|
||||||
|
func ClaimOneTimeKeys(
|
||||||
|
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
|
||||||
|
) util.JSONResponse {
|
||||||
|
var cor claimOTKsRequest
|
||||||
|
err := json.Unmarshal(request.Content(), &cor)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// make sure we only claim users on our domain
|
||||||
|
for userID := range cor.OneTimeKeys {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
delete(cor.OneTimeKeys, userID)
|
||||||
|
continue // ignore invalid users
|
||||||
|
}
|
||||||
|
if serverName != thisServer {
|
||||||
|
delete(cor.OneTimeKeys, userID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var claimRes api.PerformClaimKeysResponse
|
||||||
|
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
|
||||||
|
OneTimeKeys: cor.OneTimeKeys,
|
||||||
|
}, &claimRes)
|
||||||
|
if claimRes.Error != nil {
|
||||||
|
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
OneTimeKeys interface{} `json:"one_time_keys"`
|
||||||
|
}{claimRes.OneTimeKeys},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LocalKeys returns the local keys for the server.
|
// LocalKeys returns the local keys for the server.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
||||||
func LocalKeys(cfg *config.Dendrite) util.JSONResponse {
|
func LocalKeys(cfg *config.Dendrite) util.JSONResponse {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -54,6 +55,7 @@ func Setup(
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
userAPI userapi.UserInternalAPI,
|
userAPI userapi.UserInternalAPI,
|
||||||
stateAPI currentstateAPI.CurrentStateInternalAPI,
|
stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
|
keyAPI keyserverAPI.KeyInternalAPI,
|
||||||
) {
|
) {
|
||||||
v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter()
|
v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter()
|
||||||
v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter()
|
v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter()
|
||||||
|
@ -299,4 +301,18 @@ func Setup(
|
||||||
return GetPostPublicRooms(req, rsAPI, stateAPI)
|
return GetPostPublicRooms(req, rsAPI, stateAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
|
v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI(
|
||||||
|
"federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup,
|
||||||
|
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||||
|
return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
|
||||||
|
},
|
||||||
|
)).Methods(http.MethodPost)
|
||||||
|
|
||||||
|
v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI(
|
||||||
|
"federation_keys_query", cfg.Matrix.ServerName, keys, wakeup,
|
||||||
|
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||||
|
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
|
||||||
|
},
|
||||||
|
)).Methods(http.MethodPost)
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -21,7 +21,7 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
|
||||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM=
|
||||||
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||||
|
|
|
@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
|
||||||
federationapi.AddPublicRoutes(
|
federationapi.AddPublicRoutes(
|
||||||
publicMux, m.Config, m.UserAPI, m.FedClient,
|
publicMux, m.Config, m.UserAPI, m.FedClient,
|
||||||
m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
|
m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
|
||||||
m.EDUInternalAPI, m.StateAPI,
|
m.EDUInternalAPI, m.StateAPI, m.KeyAPI,
|
||||||
)
|
)
|
||||||
mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client)
|
mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client)
|
||||||
syncapi.AddPublicRoutes(
|
syncapi.AddPublicRoutes(
|
||||||
|
|
|
@ -24,7 +24,9 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
@ -33,6 +35,7 @@ type KeyInternalAPI struct {
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
ThisServer gomatrixserverlib.ServerName
|
ThisServer gomatrixserverlib.ServerName
|
||||||
FedClient *gomatrixserverlib.FederationClient
|
FedClient *gomatrixserverlib.FederationClient
|
||||||
|
UserAPI userapi.UserInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
|
@ -66,12 +69,26 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||||
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mergeInto(res.OneTimeKeys, keys)
|
util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
|
||||||
|
for _, key := range keys {
|
||||||
|
_, ok := res.OneTimeKeys[key.UserID]
|
||||||
|
if !ok {
|
||||||
|
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
|
||||||
|
if !ok {
|
||||||
|
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
for keyID, keyJSON := range key.KeyJSON {
|
||||||
|
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(domainToDeviceKeys, string(a.ThisServer))
|
delete(domainToDeviceKeys, string(a.ThisServer))
|
||||||
}
|
}
|
||||||
// claim remote keys
|
if len(domainToDeviceKeys) > 0 {
|
||||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) claimRemoteKeys(
|
func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
||||||
|
@ -82,6 +99,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
wg.Add(len(domainToDeviceKeys))
|
wg.Add(len(domainToDeviceKeys))
|
||||||
// mutex for failures
|
// mutex for failures
|
||||||
var failMu sync.Mutex
|
var failMu sync.Mutex
|
||||||
|
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
|
||||||
|
|
||||||
// fan out
|
// fan out
|
||||||
for d, k := range domainToDeviceKeys {
|
for d, k := range domainToDeviceKeys {
|
||||||
|
@ -91,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
defer cancel()
|
defer cancel()
|
||||||
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
|
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
||||||
failMu.Lock()
|
failMu.Lock()
|
||||||
res.Failures[domain] = map[string]interface{}{
|
res.Failures[domain] = map[string]interface{}{
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
@ -108,6 +127,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
close(resultCh)
|
close(resultCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
keysClaimed := 0
|
||||||
for result := range resultCh {
|
for result := range resultCh {
|
||||||
for userID, nest := range result.OneTimeKeys {
|
for userID, nest := range result.OneTimeKeys {
|
||||||
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
||||||
|
@ -119,10 +139,12 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
|
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
|
||||||
|
keysClaimed++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||||
|
@ -145,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pull out display names after we have the keys so we handle wildcards correctly
|
||||||
|
var dids []string
|
||||||
|
for _, dk := range deviceKeys {
|
||||||
|
dids = append(dids, dk.DeviceID)
|
||||||
|
}
|
||||||
|
var queryRes userapi.QueryDeviceInfosResponse
|
||||||
|
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
|
||||||
|
DeviceIDs: dids,
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
||||||
|
}
|
||||||
|
|
||||||
if res.DeviceKeys[userID] == nil {
|
if res.DeviceKeys[userID] == nil {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
for _, dk := range deviceKeys {
|
for _, dk := range deviceKeys {
|
||||||
// inject an empty 'unsigned' key which should be used for display names
|
// inject display name if known
|
||||||
// (but not via this API? unsure when they should be added)
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{})
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
|
||||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -298,19 +335,3 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
|
||||||
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
|
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
|
||||||
// TODO
|
// TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
|
|
||||||
for _, key := range src {
|
|
||||||
_, ok := dst[key.UserID]
|
|
||||||
if !ok {
|
|
||||||
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
|
|
||||||
}
|
|
||||||
_, ok = dst[key.UserID][key.DeviceID]
|
|
||||||
if !ok {
|
|
||||||
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
|
||||||
}
|
|
||||||
for keyID, keyJSON := range key.KeyJSON {
|
|
||||||
dst[key.UserID][key.DeviceID][keyID] = keyJSON
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/keyserver/internal"
|
"github.com/matrix-org/dendrite/keyserver/internal"
|
||||||
"github.com/matrix-org/dendrite/keyserver/inthttp"
|
"github.com/matrix-org/dendrite/keyserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
||||||
|
|
||||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI {
|
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI {
|
||||||
db, err := storage.NewDatabase(
|
db, err := storage.NewDatabase(
|
||||||
string(cfg.Database.E2EKey),
|
string(cfg.Database.E2EKey),
|
||||||
cfg.DbProperties(),
|
cfg.DbProperties(),
|
||||||
|
@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio
|
||||||
DB: db,
|
DB: db,
|
||||||
ThisServer: cfg.Matrix.ServerName,
|
ThisServer: cfg.Matrix.ServerName,
|
||||||
FedClient: fedClient,
|
FedClient: fedClient,
|
||||||
|
UserAPI: userAPI,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,9 +122,11 @@ User can invite local user to room with version 1
|
||||||
Can upload device keys
|
Can upload device keys
|
||||||
Should reject keys claiming to belong to a different user
|
Should reject keys claiming to belong to a different user
|
||||||
Can query device keys using POST
|
Can query device keys using POST
|
||||||
|
Can query remote device keys using POST
|
||||||
Can query specific device keys using POST
|
Can query specific device keys using POST
|
||||||
query for user with no keys returns empty key dict
|
query for user with no keys returns empty key dict
|
||||||
Can claim one time key using POST
|
Can claim one time key using POST
|
||||||
|
Can claim remote one time key using POST
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
|
@ -30,6 +30,7 @@ type UserInternalAPI interface {
|
||||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputAccountDataRequest is the request for InputAccountData
|
// InputAccountDataRequest is the request for InputAccountData
|
||||||
|
@ -44,6 +45,19 @@ type InputAccountDataRequest struct {
|
||||||
type InputAccountDataResponse struct {
|
type InputAccountDataResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
|
||||||
|
type QueryDeviceInfosRequest struct {
|
||||||
|
DeviceIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDeviceInfosResponse is the response to QueryDeviceInfos
|
||||||
|
type QueryDeviceInfosResponse struct {
|
||||||
|
DeviceInfo map[string]struct {
|
||||||
|
DisplayName string
|
||||||
|
UserID string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// QueryAccessTokenRequest is the request for QueryAccessToken
|
// QueryAccessTokenRequest is the request for QueryAccessToken
|
||||||
type QueryAccessTokenRequest struct {
|
type QueryAccessTokenRequest struct {
|
||||||
AccessToken string
|
AccessToken string
|
||||||
|
|
|
@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
|
||||||
|
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.DeviceInfo = make(map[string]struct {
|
||||||
|
DisplayName string
|
||||||
|
UserID string
|
||||||
|
})
|
||||||
|
for _, d := range devices {
|
||||||
|
res.DeviceInfo[d.ID] = struct {
|
||||||
|
DisplayName string
|
||||||
|
UserID string
|
||||||
|
}{
|
||||||
|
DisplayName: d.DisplayName,
|
||||||
|
UserID: d.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
|
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
|
||||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -35,6 +35,7 @@ const (
|
||||||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||||
QueryDevicesPath = "/userapi/queryDevices"
|
QueryDevicesPath = "/userapi/queryDevices"
|
||||||
QueryAccountDataPath = "/userapi/queryAccountData"
|
QueryAccountDataPath = "/userapi/queryAccountData"
|
||||||
|
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile(
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryDeviceInfos(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryDeviceInfosRequest,
|
||||||
|
response *api.QueryDeviceInfosResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryDeviceInfosPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryAccessToken(
|
func (h *httpUserInternalAPI) QueryAccessToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryAccessTokenRequest,
|
request *api.QueryAccessTokenRequest,
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// nolint: gocyclo
|
||||||
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
internalAPIMux.Handle(PerformAccountCreationPath,
|
internalAPIMux.Handle(PerformAccountCreationPath,
|
||||||
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
|
||||||
|
@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryDeviceInfosPath,
|
||||||
|
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryDeviceInfosRequest{}
|
||||||
|
response := api.QueryDeviceInfosResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ type Database interface {
|
||||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||||
|
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
|
|
@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||||
|
|
||||||
|
const selectDevicesByIDSQL = "" +
|
||||||
|
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
selectDeviceByIDStmt *sql.Stmt
|
selectDeviceByIDStmt *sql.Stmt
|
||||||
selectDevicesByLocalpartStmt *sql.Stmt
|
selectDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
selectDevicesByIDStmt *sql.Stmt
|
||||||
updateDeviceNameStmt *sql.Stmt
|
updateDeviceNameStmt *sql.Stmt
|
||||||
deleteDeviceStmt *sql.Stmt
|
deleteDeviceStmt *sql.Stmt
|
||||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
|
var displayName sql.NullString
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
if displayName.Valid {
|
||||||
|
dev.DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||||
|
var devices []api.Device
|
||||||
|
for rows.Next() {
|
||||||
|
var dev api.Device
|
||||||
|
var localpart string
|
||||||
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if displayName.Valid {
|
||||||
|
dev.DisplayName = displayName.String
|
||||||
|
}
|
||||||
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
devices = append(devices, dev)
|
||||||
|
}
|
||||||
|
return devices, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
|
|
|
@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
|
|
||||||
|
const selectDevicesByIDSQL = "" +
|
||||||
|
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer *sqlutil.TransactionWriter
|
||||||
|
@ -79,6 +83,7 @@ type devicesStatements struct {
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
selectDeviceByIDStmt *sql.Stmt
|
selectDeviceByIDStmt *sql.Stmt
|
||||||
|
selectDevicesByIDStmt *sql.Stmt
|
||||||
selectDevicesByLocalpartStmt *sql.Stmt
|
selectDevicesByLocalpartStmt *sql.Stmt
|
||||||
updateDeviceNameStmt *sql.Stmt
|
updateDeviceNameStmt *sql.Stmt
|
||||||
deleteDeviceStmt *sql.Stmt
|
deleteDeviceStmt *sql.Stmt
|
||||||
|
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
|
var displayName sql.NullString
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
if displayName.Valid {
|
||||||
|
dev.DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
}
|
}
|
||||||
|
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
||||||
|
|
||||||
return devices, nil
|
return devices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||||
|
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||||
|
for i := range deviceIDs {
|
||||||
|
iDeviceIDs[i] = deviceIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||||
|
var devices []api.Device
|
||||||
|
for rows.Next() {
|
||||||
|
var dev api.Device
|
||||||
|
var localpart string
|
||||||
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if displayName.Valid {
|
||||||
|
dev.DisplayName = displayName.String
|
||||||
|
}
|
||||||
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
|
devices = append(devices, dev)
|
||||||
|
}
|
||||||
|
return devices, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
|
Loading…
Reference in a new issue