Add display names to device key responses

This commit is contained in:
Kegan Dougal 2020-07-22 16:33:31 +01:00
parent 9d0b9cd360
commit acc0e7b63d
19 changed files with 180 additions and 12 deletions

View file

@ -186,7 +186,7 @@ func main() {
ServerKeyAPI: serverKeyAPI, ServerKeyAPI: serverKeyAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI),
ExtPublicRoomsProvider: provider, ExtPublicRoomsProvider: provider,
} }
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)

View file

@ -140,7 +140,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation, ygg, fsAPI, federation,

View file

@ -24,7 +24,7 @@ func main() {
base := setup.NewBaseDendrite(cfg, "KeyServer", true) base := setup.NewBaseDendrite(cfg, "KeyServer", true)
defer base.Close() // nolint: errcheck defer base.Close() // nolint: errcheck
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient()) intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)

View file

@ -119,7 +119,7 @@ func main() {
rsImpl.SetFederationSenderAPI(fsAPI) rsImpl.SetFederationSenderAPI(fsAPI)
stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer)
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation) keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,

View file

@ -233,7 +233,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
StateAPI: stateAPI, StateAPI: stateAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider, ExtPublicRoomsProvider: p2pPublicRoomProvider,
} }

2
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/mattn/go-sqlite3 v2.0.2+incompatible github.com/mattn/go-sqlite3 v2.0.2+incompatible

2
go.sum
View file

@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -34,6 +35,7 @@ type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName ThisServer gomatrixserverlib.ServerName
FedClient *gomatrixserverlib.FederationClient FedClient *gomatrixserverlib.FederationClient
UserAPI userapi.UserInternalAPI
} }
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
@ -107,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel() defer cancel()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
failMu.Lock() failMu.Lock()
res.Failures[domain] = map[string]interface{}{ res.Failures[domain] = map[string]interface{}{
"message": err.Error(), "message": err.Error(),
@ -164,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
return return
} }
// pull out display names after we have the keys so we handle wildcards correctly
var dids []string
for _, dk := range deviceKeys {
dids = append(dids, dk.DeviceID)
}
var queryRes userapi.QueryDeviceInfosResponse
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
DeviceIDs: dids,
}, &queryRes)
if err != nil {
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
}
if res.DeviceKeys[userID] == nil { if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage) res.DeviceKeys[userID] = make(map[string]json.RawMessage)
} }
for _, dk := range deviceKeys { for _, dk := range deviceKeys {
// inject an empty 'unsigned' key which should be used for display names // inject display name if known
// (but not via this API? unsure when they should be added) dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{}) DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
} }
} else { } else {

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/inthttp"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI { func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI {
db, err := storage.NewDatabase( db, err := storage.NewDatabase(
string(cfg.Database.E2EKey), string(cfg.Database.E2EKey),
cfg.DbProperties(), cfg.DbProperties(),
@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient, FedClient: fedClient,
UserAPI: userAPI,
} }
} }

View file

@ -122,6 +122,7 @@ User can invite local user to room with version 1
Can upload device keys Can upload device keys
Should reject keys claiming to belong to a different user Should reject keys claiming to belong to a different user
Can query device keys using POST Can query device keys using POST
Can query remote device keys using POST
Can query specific device keys using POST Can query specific device keys using POST
query for user with no keys returns empty key dict query for user with no keys returns empty key dict
Can claim one time key using POST Can claim one time key using POST

View file

@ -30,6 +30,7 @@ type UserInternalAPI interface {
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
@ -44,6 +45,19 @@ type InputAccountDataRequest struct {
type InputAccountDataResponse struct { type InputAccountDataResponse struct {
} }
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
type QueryDeviceInfosRequest struct {
DeviceIDs []string
}
// QueryDeviceInfosResponse is the response to QueryDeviceInfos
type QueryDeviceInfosResponse struct {
DeviceInfo map[string]struct {
DisplayName string
UserID string
}
}
// QueryAccessTokenRequest is the request for QueryAccessToken // QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct { type QueryAccessTokenRequest struct {
AccessToken string AccessToken string

View file

@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
return nil return nil
} }
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
res.DeviceInfo = make(map[string]struct {
DisplayName string
UserID string
})
for _, d := range devices {
res.DeviceInfo[d.ID] = struct {
DisplayName string
UserID string
}{
DisplayName: d.DisplayName,
UserID: d.UserID,
}
}
return nil
}
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {

View file

@ -35,6 +35,7 @@ const (
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices" QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData" QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) QueryDeviceInfos(
ctx context.Context,
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceInfosPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryAccessToken( func (h *httpUserInternalAPI) QueryAccessToken(
ctx context.Context, ctx context.Context,
request *api.QueryAccessTokenRequest, request *api.QueryAccessTokenRequest,

View file

@ -103,4 +103,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceInfosRequest{}
response := api.QueryDeviceInfosResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -24,6 +24,7 @@ type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return return
} }
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
} }
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {

View file

@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -20,6 +20,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -79,6 +83,7 @@ type devicesStatements struct {
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return return
} }
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
} }
return &dev, err return &dev, err
} }
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil return devices, nil
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
iDeviceIDs[i] = deviceIDs[i]
}
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}

View file

@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart) return d.devices.selectDevicesByLocalpart(ctx, localpart)
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,