diff --git a/userapi/producers/keychange.go b/userapi/producers/keychange.go index f7664b10f..da6cea31a 100644 --- a/userapi/producers/keychange.go +++ b/userapi/producers/keychange.go @@ -28,7 +28,7 @@ import ( // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { Topic string - JetStream nats.JetStreamContext + JetStream JetStreamPublisher DB storage.KeyChangeDatabase } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 17127ba00..588e40f18 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" @@ -75,11 +76,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") - + keyChangeProducer := &producers.KeyChange{DB: accountDB, JetStream: &dummyProducer{}} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, - SyncProducer: syncProducer, + DB: accountDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -432,3 +434,176 @@ func TestAccountData(t *testing.T) { } }) } + +func TestDevices(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + _ = alice + + dupeAccessToken := util.RandomString(8) + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + }) +}