diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index 11c6c7827..d0b3bdbe5 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -115,33 +115,9 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), - } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() - } - - if dev.UserID != device.UserID { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), - } - } defer req.Body.Close() // nolint: errcheck @@ -152,10 +128,28 @@ func UpdateDeviceByID( return jsonerror.InternalServerError() } - if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.UpdateDevice failed") + var performRes api.PerformDeviceUpdateResponse + err := userAPI.PerformDeviceUpdate(req.Context(), &api.PerformDeviceUpdateRequest{ + RequestingUserID: device.UserID, + DeviceID: deviceID, + DisplayName: payload.DisplayName, + }, &performRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") return jsonerror.InternalServerError() } + if !performRes.DeviceExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Forbidden("device does not exist"), + } + } + if performRes.Forbidden { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 7f47aafff..42f828f64 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -57,7 +57,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, deviceDB devices.Database, + req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { if req.Method == http.MethodGet { @@ -81,7 +81,7 @@ func Login( return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, deviceDB, login) + return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login) } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, @@ -90,7 +90,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, deviceDB devices.Database, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ) util.JSONResponse { token, err := auth.GenerateAccessToken() if err != nil { @@ -104,9 +104,13 @@ func completeAuth( return jsonerror.InternalServerError() } - dev, err := deviceDB.CreateDevice( - ctx, localpart, login.DeviceID, token, login.InitialDisplayName, - ) + var performRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + DeviceDisplayName: login.InitialDisplayName, + DeviceID: login.DeviceID, + AccessToken: token, + Localpart: localpart, + }, &performRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -117,10 +121,10 @@ func completeAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: loginResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, + UserID: performRes.Device.UserID, + AccessToken: performRes.Device.AccessToken, HomeServer: serverName, - DeviceID: dev.ID, + DeviceID: performRes.Device.ID, }, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 6c40db865..0e58129ef 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -387,7 +387,7 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return Login(req, accountDB, deviceDB, cfg) + return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -644,7 +644,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) + return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodPut, http.MethodOptions) diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index 1366a0be8..193173906 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -19,6 +19,7 @@ import ( "crypto/ed25519" "encoding/json" "net/http" + "os" "reflect" "testing" "time" @@ -91,11 +92,13 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms return nil } -func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer) { +func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) { cfg := &config.Dendrite{} + stateDBName := "test_state.db" + naffkaDBName := "test_naffka.db" cfg.Kafka.Topics.OutputRoomEvent = config.Topic(kafkaTopic) - cfg.Database.CurrentState = config.DataSource("file::memory:") - db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file::memory:", nil) + cfg.Database.CurrentState = config.DataSource("file:" + stateDBName) + db, err := sqlutil.Open(sqlutil.SQLiteDriverName(), "file:"+naffkaDBName, nil) if err != nil { t.Fatalf("Failed to open naffka database: %s", err) } @@ -107,11 +110,15 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync if err != nil { t.Fatalf("Failed to create naffka consumer: %s", err) } - return NewInternalAPI(cfg, naff), naff + return NewInternalAPI(cfg, naff), naff, func() { + os.Remove(naffkaDBName) + os.Remove(stateDBName) + } } func TestQueryCurrentState(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() plTuple := gomatrixserverlib.StateKeyTuple{ EventType: "m.room.power_levels", StateKey: "", @@ -209,7 +216,8 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r // This test makes sure that QuerySharedUsers is returning the correct users for a range of sets. func TestQuerySharedUsers(t *testing.T) { - currStateAPI, producer := MustMakeInternalAPI(t) + currStateAPI, producer, cancel := MustMakeInternalAPI(t) + defer cancel() MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join")) MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join")) @@ -222,6 +230,9 @@ func TestQuerySharedUsers(t *testing.T) { MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join")) + // we don't know when the server has processed the events + time.Sleep(10 * time.Millisecond) + testCases := []struct { req api.QuerySharedUsersRequest wantRes api.QuerySharedUsersResponse diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 480d1084e..bb8286635 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -206,6 +206,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } // inject display name if known dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index cb4fca7d8..0272ffc06 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -16,6 +16,7 @@ package internal import ( "context" + "strings" "github.com/Shopify/sarama" currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" @@ -88,6 +89,16 @@ func DeviceListCatchup( if !userSet[userID] { res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) hasNew = true + userSet[userID] = true + } + } + // if the response has any join/leave events, add them now. + // TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them. + for _, userID := range membershipEvents(res) { + if !userSet[userID] { + res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID) + hasNew = true + userSet[userID] = true } } return hasNew, nil @@ -219,3 +230,25 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin } return false } + +// returns the user IDs of anyone joining or leaving a room in this response. These users will be added to +// the 'changed' property because of https://matrix.org/docs/spec/client_server/r0.6.1#id84 +// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device, +// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler +// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room." +func membershipEvents(res *types.Response) (userIDs []string) { + for _, room := range res.Rooms.Join { + for _, ev := range room.Timeline.Events { + if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { + if strings.Contains(string(ev.Content), `"join"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"leave"`) { + userIDs = append(userIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"ban"`) { + userIDs = append(userIDs, *ev.StateKey) + } + } + } + } + return +} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index f817f0981..b530b34d1 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -168,7 +168,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } // work out room joins/leaves res, err := rp.db.IncrementalSync( - req.Context(), types.NewResponse(), *device, fromToken, toToken, 0, false, + req.Context(), types.NewResponse(), *device, fromToken, toToken, 10, false, ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to IncrementalSync") diff --git a/sytest-whitelist b/sytest-whitelist index 03baf4d44..16a71c648 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -129,7 +129,11 @@ Can claim one time key using POST Can claim remote one time key using POST Local device key changes appear in v2 /sync Local device key changes appear in /keys/changes +New users appear in /keys/changes Local delete device changes appear in v2 /sync +Local new device changes appear in v2 /sync +Local update device changes appear in v2 /sync +Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Can add account data Can add account data to room diff --git a/userapi/api/api.go b/userapi/api/api.go index 5c964c4fd..84338dbf2 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -28,6 +28,7 @@ type UserInternalAPI interface { PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -48,6 +49,16 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +type PerformDeviceUpdateRequest struct { + RequestingUserID string + DeviceID string + DisplayName *string +} +type PerformDeviceUpdateResponse struct { + DeviceExists bool + Forbidden bool +} + type PerformDeviceDeletionRequest struct { UserID string // The devices to delete diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 738023dd6..b9d188229 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -104,7 +104,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - return nil + // create empty device keys and upload them to trigger device list changes + return a.deviceListUpdate(dev.UserID, []string{dev.ID}) } func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { @@ -121,10 +122,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs)) - for i, did := range req.DeviceIDs { + return a.deviceListUpdate(req.UserID, req.DeviceIDs) +} + +func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { + deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + for i, did := range deviceIDs { deviceKeys[i] = keyapi.DeviceKeys{ - UserID: req.UserID, + UserID: userID, DeviceID: did, KeyJSON: nil, } @@ -143,6 +148,35 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return nil } +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } + return nil +} + func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 47e2110f9..5f4df0eb1 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -31,6 +31,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" @@ -104,6 +105,14 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformDeviceUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) QueryProfile( ctx context.Context, request *api.QueryProfileRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ebb9bf4e8..47d68ff21 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeviceUpdatePath, + httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformDeviceUpdateRequest{} + response := api.PerformDeviceUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceDeletionPath, httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceDeletionRequest{} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index efe6f927c..9b535aab9 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -174,7 +174,7 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) + orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) prep, err := s.db.Prepare(orig) if err != nil { return err @@ -186,7 +186,6 @@ func (s *devicesStatements) deleteDevices( for i, v := range devices { params[i+1] = v } - params = append(params, params...) _, err = stmt.ExecContext(ctx, params...) return err })