From 9dd1824741ab13a13ae84aadf5a0226831a118ee Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 7 Jul 2021 16:05:53 -0700
Subject: [PATCH] Fix dispatch keys for eigh, lu_solve (#60945)

Summary:
I added a test to `test_ops.py` that verifies that an op runs correctly on different CUDA devices. This test revealed that `linalg_eigh`, `linalg_eigvalsh`, `linalg_matrix_rank`, and `linalg_pinv` were failing; `matrix_rank` and `pinv` call `eigh` internally. `linalg_eigh` and `lu_solve` use dispatch stubs internally, so they should be registered with the `CPU, CUDA` dispatch keys. With these dispatch keys the generated code includes device guards, and the problem does not occur. Implemented a better out variant for `eigvalsh` and registered it with the `CPU, CUDA` dispatch keys.

~I added a device guard to `linalg_eigh_kernel` as a fix for the `eigvalsh` function. This function needs to be registered as CompositeImplicitAutograd, because it calls `at::linalg_eigh` if `at::GradMode::is_enabled()`.~

Fixes https://github.com/pytorch/pytorch/issues/60892.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60945

Reviewed By: mruberry

Differential Revision: D29589580

Pulled By: ngimel

fbshipit-source-id: 5851605958bdfc3a1a1768263934619449957168
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp | 59 ++++++++++---------
 aten/src/ATen/native/BatchLinearAlgebra.h | 6 +-
 .../ATen/native/BatchLinearAlgebraKernel.cpp | 4 +-
 .../ATen/native/cuda/BatchLinearAlgebra.cu | 13 ++--
 .../ATen/native/cuda/BatchLinearAlgebraLib.cu | 12 ++--
 .../ATen/native/cuda/BatchLinearAlgebraLib.h | 2 +-
 aten/src/ATen/native/native_functions.yaml | 10 ++--
 test/test_linalg.py | 7 +--
 test/test_ops.py | 22 ++++++-
 .../_internal/common_methods_invocations.py | 10 ++--
 10 files changed, 86 insertions(+), 59 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 03aefb1dce..94b842ec41 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2246,11 +2246,11 @@ DEFINE_DISPATCH(linalg_eigh_stub);
 * 'uplo_str' - controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
   "u", "U" - upper triangular portion of the input matrix is used in computations; "l", "L" - lower.
*/ -std::tuple linalg_eigh_out_info( +void linalg_eigh_out_info( const Tensor& input, - Tensor& values, - Tensor& vectors, - Tensor& infos, + const Tensor& values, + const Tensor& vectors, + const Tensor& infos, bool compute_eigenvectors, const c10::string_view uplo_str) { // These internal asserts make explicit the assumptions in the implementation @@ -2306,8 +2306,6 @@ std::tuple linalg_eigh_out_info( bool upper = (uplo == 'U'); linalg_eigh_stub(input.device().type(), values, vectors, infos, upper, compute_eigenvectors); - - return std::tuple(values, vectors); } std::tuple linalg_eigh(const Tensor& input, c10::string_view uplo) { @@ -2318,7 +2316,7 @@ std::tuple linalg_eigh(const Tensor& input, c10::string_view upl Tensor vectors = at::empty({0}, input.options()); Tensor infos = at::zeros({std::max(1, batchCount(input))}, input.options().dtype(kInt)); - std::tie(values, vectors) = linalg_eigh_out_info(input, values, vectors, infos, true, uplo); + linalg_eigh_out_info(input, values, vectors, infos, true, uplo); if (input.dim() > 2) { batchCheckErrors(infos, "torch.linalg.eigh"); @@ -2332,8 +2330,6 @@ std::tuple linalg_eigh(const Tensor& input, c10::string_view upl // TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out // TODO: implement _out variant avoiding copy and using already allocated storage directly std::tuple linalg_eigh_out(const Tensor& input, c10::string_view uplo, Tensor& eigvals, Tensor& eigvecs) { - checkSameDevice("torch.linalg.eigh", eigvecs, input, "eigenvectors"); - checkSameDevice("torch.linalg.eigh", eigvals, input, "eigenvalues"); checkLinalgCompatibleDtype("torch.linalg.eigh", eigvecs, input, "eigenvectors"); // eigenvalues are always real-valued here @@ -2360,14 +2356,38 @@ Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) { return values; } - squareCheckInputs(input); - checkUplo(uplo); ScalarType real_dtype = toValueType(input.scalar_type()); Tensor values = at::empty({0}, input.options().dtype(real_dtype)); + values = at::linalg_eigvalsh_outf(input, uplo, values); + return values; +} + +Tensor& linalg_eigvalsh_out(const Tensor& input, c10::string_view uplo, Tensor& result) { + ScalarType real_dtype = toValueType(input.scalar_type()); + checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype); + + squareCheckInputs(input); + checkUplo(uplo); + + auto expected_result_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1] + bool result_equal_expected_shape = result.sizes().equals(expected_result_shape); + bool expected_result_type = (result.scalar_type() == real_dtype); + bool copy_needed = !expected_result_type; + copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); + copy_needed |= (result.numel() != 0 && !result.is_contiguous()); + Tensor vectors = at::empty({0}, input.options()); Tensor infos = at::zeros({std::max(1, batchCount(input))}, input.options().dtype(kInt)); - std::tie(values, vectors) = linalg_eigh_out_info(input, values, vectors, infos, false, uplo); + if (copy_needed) { // we have to allocate a temporary tensor + Tensor result_tmp = at::empty({expected_result_shape}, input.options().dtype(real_dtype)); + linalg_eigh_out_info(input, result_tmp, vectors, infos, /*compute_eigenvectors=*/false, uplo); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + } else { + // else use the provided output storage directly + linalg_eigh_out_info(input, result, vectors, infos, 
/*compute_eigenvectors=*/false, uplo); + } if (input.dim() > 2) { batchCheckErrors(infos, "torch.linalg.eigvalsh"); @@ -2375,21 +2395,6 @@ Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) { singleCheckErrors(infos.item().toInt(), "torch.linalg.eigvalsh"); } - return values; -} - -// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out -// TODO: implement _out variant avoiding copy and using already allocated storage directly -Tensor& linalg_eigvalsh_out(const Tensor& input, c10::string_view uplo, Tensor& result) { - checkSameDevice("torch.linalg.eigvalsh", result, input); - ScalarType real_dtype = toValueType(input.scalar_type()); - checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype); - - Tensor result_tmp = at::linalg_eigvalsh(input, uplo); - - at::native::resize_output(result, result_tmp.sizes()); - result.copy_(result_tmp); - return result; } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 37fded3acd..0c9046c76f 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -191,9 +191,9 @@ using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const DECLARE_DISPATCH(ormqr_fn, ormqr_stub); using linalg_eigh_fn = void (*)( - Tensor& /*eigenvalues*/, - Tensor& /*eigenvectors*/, - Tensor& /*infos*/, + const Tensor& /*eigenvalues*/, + const Tensor& /*eigenvectors*/, + const Tensor& /*infos*/, bool /*upper*/, bool /*compute_eigenvectors*/); DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 86138b9b4f..b801f819f3 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -293,7 +293,7 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, This function doesn't do any error checks and it's assumed that every argument is valid. 
*/ template -void apply_lapack_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #if !AT_BUILD_WITH_LAPACK() TORCH_CHECK( false, @@ -365,7 +365,7 @@ void apply_lapack_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool uppe } // This is a type dispatching helper function for 'apply_lapack_eigh' -void linalg_eigh_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { // This function calculates the symmetric/hermitian eigendecomposition // in-place tensors should be in batched column major memory format the // content of eigenvalues, eigenvectors and infos is overwritten by diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 2e540e34a1..0339d9304a 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -2268,7 +2268,7 @@ std::tuple _linalg_qr_helper_cuda(const Tensor& input, c10::stri // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_magma_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #ifndef USE_MAGMA TORCH_CHECK( false, @@ -2381,23 +2381,24 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec // This is a type dispatch function for 'apply_magma_eigh' // For small inputs result is computed on CPU -void linalg_eigh_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { // MAGMA just calls LAPACK for eigenvectors.size(-1) <= 128 // See https://bitbucket.org/icl/magma/src/e6fdca447bd402693e8b0b950a898b6879bbcc41/src/zheevd_gpu.cpp?at=master#lines-258 // in addition lda is ignored breaking 0x0 inputs if (eigenvectors.size(-1) > 128) { // MAGMA requires eigenvalues and infos tensors to reside on CPU Tensor eigenvalues_cpu = eigenvalues.to(kCPU); - infos = infos.to(kCPU); + Tensor infos_cpu = infos.to(kCPU); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( - eigenvectors.scalar_type(), "linalg_eigh_cpu", [&] { + eigenvectors.scalar_type(), "linalg_eigh_magma", [&] { apply_magma_eigh( - eigenvalues_cpu, eigenvectors, infos, upper, compute_eigenvectors); + eigenvalues_cpu, eigenvectors, infos_cpu, upper, compute_eigenvectors); }); // Transfer computed by MAGMA results from CPU to GPU eigenvalues.copy_(eigenvalues_cpu); + infos.copy_(infos_cpu); } else { // eigenvectors.size(-1) <= 128 // transfer to CPU, compute the result and copy back to GPU // this is faster than going through MAGMA that does the same @@ -2413,7 +2414,7 @@ void linalg_eigh_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, } } -void linalg_eigh_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { 
#if defined(USE_CUSOLVER) linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); #else diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu index 1ee09f38a7..9586217a41 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu @@ -1027,7 +1027,7 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) { } template -static void apply_syevd(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { using value_t = typename c10::scalar_value_type::type; cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; @@ -1114,7 +1114,7 @@ static void apply_syevd(Tensor& values, Tensor& vectors, Tensor& infos, bool upp } template -static void apply_syevj(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void apply_syevj(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { using value_t = typename c10::scalar_value_type::type; cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; @@ -1171,7 +1171,7 @@ static void apply_syevj(Tensor& values, Tensor& vectors, Tensor& infos, bool upp } template -static void apply_syevj_batched(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { using value_t = typename c10::scalar_value_type::type; cublasFillMode_t uplo = upper ? 
CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; @@ -1230,19 +1230,19 @@ static void apply_syevj_batched(Tensor& values, Tensor& vectors, Tensor& infos, TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params)); } -static void linalg_eigh_cusolver_syevd(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void linalg_eigh_cusolver_syevd(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] { apply_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); }); } -static void linalg_eigh_cusolver_syevj(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +static void linalg_eigh_cusolver_syevj(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] { apply_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); }); } -void linalg_eigh_cusolver(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) { +void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { // TODO: syevj_batched should be added here, but at least for CUDA 11.2 it contains a bug leading to incorrect results // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-793626268 and https://github.com/cupy/cupy/issues/4847 diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h index 74f405bebb..11e546c37d 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h @@ -48,7 +48,7 @@ void geqrf_cusolver(const Tensor& input, const Tensor& tau); void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose); Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau); -void linalg_eigh_cusolver(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors); +void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors); void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots); void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 94f5e7d884..6bab7b5eda 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6680,12 +6680,12 @@ - func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CompositeExplicitAutograd: lu_solve_out + CPU, CUDA: lu_solve_out - func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor variants: method, function dispatch: - CompositeExplicitAutograd: lu_solve + CPU, CUDA: lu_solve - func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) variants: function @@ -10094,12 +10094,12 @@ python_module: linalg variants: function dispatch: - CompositeExplicitAutograd: linalg_eigh + CPU, CUDA: linalg_eigh - func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) python_module: linalg dispatch: - CompositeExplicitAutograd: linalg_eigh_out + CPU, CUDA: linalg_eigh_out - func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor python_module: linalg @@ -10107,6 +10107,8 @@ - func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!) python_module: linalg + dispatch: + CPU, CUDA: linalg_eigvalsh_out - func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor python_module: linalg diff --git a/test/test_linalg.py b/test/test_linalg.py index e63c554dee..c005d2f202 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -6430,13 +6430,12 @@ def test_solve_methods_arg_device(self, device): # b and A have to be modified to match accepted inputs sizes for lu_solve b = b.unsqueeze(0) A = A.unsqueeze(0) - with self.assertRaisesRegex(RuntimeError, generic_backend_dispatch_err_str): + with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int()) # This checks if a suitable error message is thrown - # when LU output and pivots are on the same device - with self.assertRaisesRegex(RuntimeError, - "Expected LU_pivots and LU_data to be on the same device"): + # when LU output and pivots are not on the same device + with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) diff --git a/test/test_ops.py b/test/test_ops.py index e58dadea27..61f45490ec 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -12,7 +12,7 @@ from torch.testing._internal.common_methods_invocations import \ (op_db, _NOTHING, UnaryUfuncInfo, SpectralFuncInfo) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes) + (deviceCountAtLeast, instantiate_device_type_tests, ops, onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes) from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \ check_alias_annotation @@ -171,6 +171,26 @@ def unsupported(dtype): self.assertEqual(supported_backward_dtypes, claimed_backward_supported, msg=msg) + # Validates that each OpInfo works correctly on different CUDA devices + @skipCUDAIfRocm + @onlyCUDA + @deviceCountAtLeast(2) + @ops(op_db, allowed_dtypes=(torch.float32, torch.long)) + def test_multiple_devices(self, devices, dtype, op): + for cuda_device_str in devices: + cuda_device = torch.device(cuda_device_str) + # NOTE: only tests on first sample + samples = op.sample_inputs(cuda_device, dtype) + sample = samples[0] + result = op(sample.input, *sample.args, **sample.kwargs) + + if 
isinstance(result, torch.Tensor): + self.assertTrue(result.device == cuda_device) + elif is_iterable_of_tensors(result): + self.assertTrue(all(map(lambda t: t.device == cuda_device, result))) + else: + self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.") + # Tests that the function and its (ndarray-accepting) reference produce the same # values on the tensors from sample_inputs func for the corresponding op. @onlyOnCPUAndCUDA diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 430517b57c..94960be518 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2086,11 +2086,11 @@ def large_1d_unique(dtype, device): samples.append(SampleInput(scalar)) samples.append(SampleInput(scalar, args=(0,))) samples.append(SampleInput(scalar, args=(0, True))) - # no CUDA support for stable sort yet - if not device.startswith('cuda'): - samples.append(SampleInput(scalar, kwargs=dict(stable=True))) - samples.append(SampleInput(scalar, kwargs=dict(dim=0, stable=True))) - samples.append(SampleInput(scalar, kwargs=dict(dim=0, descending=True, stable=True))) + + # Test cases for stable sort + samples.append(SampleInput(scalar, kwargs=dict(stable=True))) + samples.append(SampleInput(scalar, kwargs=dict(dim=0, stable=True))) + samples.append(SampleInput(scalar, kwargs=dict(dim=0, descending=True, stable=True))) return samples def sample_inputs_index_fill(op_info, device, dtype, requires_grad, **kwargs):
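
For context, the scenario exercised by the new `test_multiple_devices` test can also be reproduced directly. The snippet below is a minimal, illustrative sketch (not part of the patch) and assumes a machine with at least two visible CUDA devices; before the `CPU, CUDA` dispatch registrations above, calls like these could fail on a non-default device because no device guard was emitted.

```python
# Illustrative only: assumes at least two CUDA devices are visible.
import torch

device = torch.device("cuda:1")  # any device other than the current default

# linalg.eigvalsh on a non-default device: the result should land on the
# same device as the input.
a = torch.randn(4, 4, device=device)
a = a + a.transpose(-2, -1)  # make the matrix symmetric
w = torch.linalg.eigvalsh(a)
assert w.device == device

# Same idea for lu_solve, with all inputs on the non-default device.
b = torch.randn(4, 1, device=device)
lu, pivots = torch.lu(a)
x = torch.lu_solve(b, lu, pivots)
assert x.device == device
```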