Fix edge cases around device lists (#1234)
* Fix New users appear in /keys/changes * Create blank device keys when logging in on a new device * Add PerformDeviceUpdate and fix a few bugs - Correct device deletion query on sqlite - Return no keys on /keys/query rather than an empty key * Unbreak sqlite properly * Use a real DB for currentstateserver integration tests * Race fix
This commit is contained in:
parent
a7e67e65a8
commit
b5cb1d1534
|
@ -115,33 +115,9 @@ func GetDevicesByLocalpart(
|
||||||
|
|
||||||
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
||||||
func UpdateDeviceByID(
|
func UpdateDeviceByID(
|
||||||
req *http.Request, deviceDB devices.Database, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := req.Context()
|
|
||||||
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown device"),
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if dev.UserID != device.UserID {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
JSON: jsonerror.Forbidden("device not owned by current user"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defer req.Body.Close() // nolint: errcheck
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
@ -152,10 +128,28 @@ func UpdateDeviceByID(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
|
var performRes api.PerformDeviceUpdateResponse
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed")
|
err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{
|
||||||
|
RequestingUserID: device.UserID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
DisplayName: payload.DisplayName,
|
||||||
|
}, &performRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
if !performRes.DeviceExists {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.Forbidden("device does not exist"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if performRes.Forbidden {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("device not owned by current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
@ -23,8 +23,8 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -57,7 +57,7 @@ func passwordLogin() flows {
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
|
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
|
@ -81,7 +81,7 @@ func Login(
|
||||||
return *authErr
|
return *authErr
|
||||||
}
|
}
|
||||||
// make a device/access token
|
// make a device/access token
|
||||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login)
|
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
@ -90,7 +90,7 @@ func Login(
|
||||||
}
|
}
|
||||||
|
|
||||||
func completeAuth(
|
func completeAuth(
|
||||||
ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login,
|
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
token, err := auth.GenerateAccessToken()
|
token, err := auth.GenerateAccessToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -104,9 +104,13 @@ func completeAuth(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err := deviceDB.CreateDevice(
|
var performRes userapi.PerformDeviceCreationResponse
|
||||||
ctx, localpart, login.DeviceID, token, login.InitialDisplayName,
|
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
||||||
)
|
DeviceDisplayName: login.InitialDisplayName,
|
||||||
|
DeviceID: login.DeviceID,
|
||||||
|
AccessToken: token,
|
||||||
|
Localpart: localpart,
|
||||||
|
}, &performRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
|
@ -117,10 +121,10 @@ func completeAuth(
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: loginResponse{
|
JSON: loginResponse{
|
||||||
UserID: dev.UserID,
|
UserID: performRes.Device.UserID,
|
||||||
AccessToken: dev.AccessToken,
|
AccessToken: performRes.Device.AccessToken,
|
||||||
HomeServer: serverName,
|
HomeServer: serverName,
|
||||||
DeviceID: dev.ID,
|
DeviceID: performRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -387,7 +387,7 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/login",
|
r0mux.Handle("/login",
|
||||||
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
return Login(req, accountDB, deviceDB, cfg)
|
return Login(req, accountDB, userAPI, cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
@ -644,7 +644,7 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
|
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -91,11 +92,13 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) {
|
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
|
||||||
cfg := &config.Dendrite{}
|
cfg := &config.Dendrite{}
|
||||||
|
stateDBName := "test_state.db"
|
||||||
|
naffkaDBName := "test_naffka.db"
|
||||||
cfg.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic)
|
cfg.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic)
|
||||||
cfg.Database.CurrentState = config.DataSource("file::memory:")
|
cfg.Database.CurrentState = config.DataSource("file:" + stateDBName)
|
||||||
db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file::memory:", nil)
|
db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open naffka database: %s", err)
|
t.Fatalf("Failed to open naffka database: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -107,11 +110,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create naffka consumer: %s", err)
|
t.Fatalf("Failed to create naffka consumer: %s", err)
|
||||||
}
|
}
|
||||||
return NewInternalAPI(cfg, naff), naff
|
return NewInternalAPI(cfg, naff), naff, func() {
|
||||||
|
os.Remove(naffkaDBName)
|
||||||
|
os.Remove(stateDBName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryCurrentState(t *testing.T) {
|
func TestQueryCurrentState(t *testing.T) {
|
||||||
currStateAPI, producer := MustMakeInternalAPI(t)
|
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
||||||
|
defer cancel()
|
||||||
plTuple := gomatrixserverlib.StateKeyTuple{
|
plTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
EventType: "m.room.power_levels",
|
EventType: "m.room.power_levels",
|
||||||
StateKey: "",
|
StateKey: "",
|
||||||
|
@ -209,7 +216,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
|
||||||
|
|
||||||
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
||||||
func TestQuerySharedUsers(t *testing.T) {
|
func TestQuerySharedUsers(t *testing.T) {
|
||||||
currStateAPI, producer := MustMakeInternalAPI(t)
|
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
||||||
|
defer cancel()
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
||||||
|
|
||||||
|
@ -222,6 +230,9 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
|
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
||||||
|
|
||||||
|
// we don't know when the server has processed the events
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
req api.QuerySharedUsersRequest
|
req api.QuerySharedUsersRequest
|
||||||
wantRes api.QuerySharedUsersResponse
|
wantRes api.QuerySharedUsersResponse
|
||||||
|
|
|
@ -206,6 +206,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
}
|
}
|
||||||
for _, dk := range deviceKeys {
|
for _, dk := range deviceKeys {
|
||||||
|
if len(dk.KeyJSON) == 0 {
|
||||||
|
continue // don't include blank keys
|
||||||
|
}
|
||||||
// inject display name if known
|
// inject display name if known
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
|
|
@ -16,6 +16,7 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
@ -88,6 +89,16 @@ func DeviceListCatchup(
|
||||||
if !userSet[userID] {
|
if !userSet[userID] {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
hasNew = true
|
hasNew = true
|
||||||
|
userSet[userID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if the response has any join/leave events, add them now.
|
||||||
|
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||||
|
for _, userID := range membershipEvents(res) {
|
||||||
|
if !userSet[userID] {
|
||||||
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
|
hasNew = true
|
||||||
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hasNew, nil
|
return hasNew, nil
|
||||||
|
@ -219,3 +230,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to
|
||||||
|
// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||||
|
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
|
||||||
|
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
|
||||||
|
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
|
||||||
|
func membershipEvents(res *types.Response) (userIDs []string) {
|
||||||
|
for _, room := range res.Rooms.Join {
|
||||||
|
for _, ev := range room.Timeline.Events {
|
||||||
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
|
||||||
|
if strings.Contains(string(ev.Content), `"join"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
} else if strings.Contains(string(ev.Content), `"leave"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
} else if strings.Contains(string(ev.Content), `"ban"`) {
|
||||||
|
userIDs = append(userIDs, *ev.StateKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
}
|
}
|
||||||
// work out room joins/leaves
|
// work out room joins/leaves
|
||||||
res, err := rp.db.IncrementalSync(
|
res, err := rp.db.IncrementalSync(
|
||||||
req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false,
|
req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync")
|
util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync")
|
||||||
|
|
|
@ -129,7 +129,11 @@ Can claim one time key using POST
|
||||||
Can claim remote one time key using POST
|
Can claim remote one time key using POST
|
||||||
Local device key changes appear in v2 /sync
|
Local device key changes appear in v2 /sync
|
||||||
Local device key changes appear in /keys/changes
|
Local device key changes appear in /keys/changes
|
||||||
|
New users appear in /keys/changes
|
||||||
Local delete device changes appear in v2 /sync
|
Local delete device changes appear in v2 /sync
|
||||||
|
Local new device changes appear in v2 /sync
|
||||||
|
Local update device changes appear in v2 /sync
|
||||||
|
Users receive device_list updates for their own devices
|
||||||
Get left notifs for other users in sync and /keys/changes when user leaves
|
Get left notifs for other users in sync and /keys/changes when user leaves
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
|
|
|
@ -28,6 +28,7 @@ type UserInternalAPI interface {
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
|
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
|
||||||
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
|
@ -48,6 +49,16 @@ type InputAccountDataRequest struct {
|
||||||
type InputAccountDataResponse struct {
|
type InputAccountDataResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PerformDeviceUpdateRequest struct {
|
||||||
|
RequestingUserID string
|
||||||
|
DeviceID string
|
||||||
|
DisplayName *string
|
||||||
|
}
|
||||||
|
type PerformDeviceUpdateResponse struct {
|
||||||
|
DeviceExists bool
|
||||||
|
Forbidden bool
|
||||||
|
}
|
||||||
|
|
||||||
type PerformDeviceDeletionRequest struct {
|
type PerformDeviceDeletionRequest struct {
|
||||||
UserID string
|
UserID string
|
||||||
// The devices to delete
|
// The devices to delete
|
||||||
|
|
|
@ -104,7 +104,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
res.DeviceCreated = true
|
res.DeviceCreated = true
|
||||||
res.Device = dev
|
res.Device = dev
|
||||||
return nil
|
// create empty device keys and upload them to trigger device list changes
|
||||||
|
return a.deviceListUpdate(dev.UserID, []string{dev.ID})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error {
|
func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error {
|
||||||
|
@ -121,10 +122,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// create empty device keys and upload them to delete what was once there and trigger device list changes
|
// create empty device keys and upload them to delete what was once there and trigger device list changes
|
||||||
deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs))
|
return a.deviceListUpdate(req.UserID, req.DeviceIDs)
|
||||||
for i, did := range req.DeviceIDs {
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
||||||
|
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
|
||||||
|
for i, did := range deviceIDs {
|
||||||
deviceKeys[i] = keyapi.DeviceKeys{
|
deviceKeys[i] = keyapi.DeviceKeys{
|
||||||
UserID: req.UserID,
|
UserID: userID,
|
||||||
DeviceID: did,
|
DeviceID: did,
|
||||||
KeyJSON: nil,
|
KeyJSON: nil,
|
||||||
}
|
}
|
||||||
|
@ -143,6 +148,35 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
res.DeviceExists = false
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.DeviceExists = true
|
||||||
|
|
||||||
|
if dev.UserID != req.RequestingUserID {
|
||||||
|
res.Forbidden = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error {
|
func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error {
|
||||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -31,6 +31,7 @@ const (
|
||||||
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
||||||
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
||||||
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
|
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
|
||||||
|
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
||||||
|
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||||
|
@ -104,6 +105,14 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformDeviceUpdatePath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryProfile(
|
func (h *httpUserInternalAPI) QueryProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryProfileRequest,
|
request *api.QueryProfileRequest,
|
||||||
|
|
|
@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(PerformDeviceUpdatePath,
|
||||||
|
httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformDeviceUpdateRequest{}
|
||||||
|
response := api.PerformDeviceUpdateResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(PerformDeviceDeletionPath,
|
internalAPIMux.Handle(PerformDeviceDeletionPath,
|
||||||
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.PerformDeviceDeletionRequest{}
|
request := api.PerformDeviceDeletionRequest{}
|
||||||
|
|
|
@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice(
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1)
|
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||||
prep, err := s.db.Prepare(orig)
|
prep, err := s.db.Prepare(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices(
|
||||||
for i, v := range devices {
|
for i, v := range devices {
|
||||||
params[i+1] = v
|
params[i+1] = v
|
||||||
}
|
}
|
||||||
params = append(params, params...)
|
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in a new issue