Fix dispatch keys for eigh, lu_solve (#60945)
Summary:
I added a test to `test_ops.py` that verifies that each op runs correctly on different CUDA devices. This test revealed that `linalg_eigh`, `linalg_eigvalsh`, `linalg_matrix_rank`, and `linalg_pinv` were failing; `matrix_rank` and `pinv` call `eigh` internally.

`linalg_eigh` and `lu_solve` internally use dispatch stubs, so they should be registered with the `CPU, CUDA` dispatch keys. With that registration the generated code includes device guards, and the problem no longer occurs.
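
For context, the failure mode is an op misbehaving when its inputs live on a non-default CUDA device. A minimal sketch of the scenario, along the lines of the new `test_multiple_devices` test added below (assumptions: a machine with at least two CUDA devices, and `torch.linalg.eigh` as the op under test):

```python
import torch

# Sketch only: before this change, the composite registration lacked a device
# guard, so a call like this could fail for inputs on a non-default CUDA device.
a = torch.randn(4, 4, device="cuda:1")
a = a + a.t()  # symmetrize so the input satisfies eigh's assumptions

w, v = torch.linalg.eigh(a)
assert w.device == a.device and v.device == a.device
```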

Implemented a better out variant for `eigvalsh` and registered it with `CPU, CUDA` dispatch keys.
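
The reworked out variant (see the `BatchLinearAlgebra.cpp` diff below) writes into the provided output tensor directly when its dtype and shape already match, and otherwise computes into a temporary, resizes the output, and copies. A hedged sketch of what that means at the Python level:

```python
import torch

a = torch.randn(3, 3, device="cuda")
a = a + a.t()

# If `out` already has the expected real dtype and shape, its storage can be
# used directly; otherwise the result is computed into a temporary and copied
# over after resizing `out`. Either way the returned tensor is `out`.
out = torch.empty(3, dtype=torch.float32, device="cuda")
res = torch.linalg.eigvalsh(a, out=out)
assert res is out
```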

~I added a device guard to `linalg_eigh_kernel` as a fix for the `eigvalsh` function. This function needs to be registered as CompositeImplicitAutograd because it calls `at::linalg_eigh` if `at::GradMode::is_enabled()`.~

Fixes pytorch/pytorch#60892.

Pull Request resolved: pytorch/pytorch#60945

Reviewed By: mruberry

Differential Revision: D29589580

Pulled By: ngimel

fbshipit-source-id: 5851605958bdfc3a1a1768263934619449957168
IvanYashchuk authored and facebook-github-bot committed Jul 7, 2021
1 parent fb00194 commit 9dd1824
Showing 10 changed files with 86 additions and 59 deletions.
59 changes: 32 additions & 27 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2246,11 +2246,11 @@ DEFINE_DISPATCH(linalg_eigh_stub);
* 'uplo_str' - controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
"u", "U" - upper triangular portion of the input matrix is used in computations; "l", "L" - lower.
*/
std::tuple<Tensor&, Tensor&> linalg_eigh_out_info(
void linalg_eigh_out_info(
const Tensor& input,
Tensor& values,
Tensor& vectors,
Tensor& infos,
const Tensor& values,
const Tensor& vectors,
const Tensor& infos,
bool compute_eigenvectors,
const c10::string_view uplo_str) {
// These internal asserts make explicit the assumptions in the implementation
@@ -2306,8 +2306,6 @@ std::tuple<Tensor&, Tensor&> linalg_eigh_out_info(
bool upper = (uplo == 'U');

linalg_eigh_stub(input.device().type(), values, vectors, infos, upper, compute_eigenvectors);

return std::tuple<Tensor&, Tensor&>(values, vectors);
}

std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& input, c10::string_view uplo) {
@@ -2318,7 +2316,7 @@ std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& input, c10::string_view upl
Tensor vectors = at::empty({0}, input.options());
Tensor infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));

std::tie(values, vectors) = linalg_eigh_out_info(input, values, vectors, infos, true, uplo);
linalg_eigh_out_info(input, values, vectors, infos, true, uplo);

if (input.dim() > 2) {
batchCheckErrors(infos, "torch.linalg.eigh");
@@ -2332,8 +2330,6 @@ std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& input, c10::string_view upl
// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& input, c10::string_view uplo, Tensor& eigvals, Tensor& eigvecs) {
checkSameDevice("torch.linalg.eigh", eigvecs, input, "eigenvectors");
checkSameDevice("torch.linalg.eigh", eigvals, input, "eigenvalues");
checkLinalgCompatibleDtype("torch.linalg.eigh", eigvecs, input, "eigenvectors");

// eigenvalues are always real-valued here
@@ -2360,36 +2356,45 @@ Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) {
return values;
}

squareCheckInputs(input);
checkUplo(uplo);
ScalarType real_dtype = toValueType(input.scalar_type());
Tensor values = at::empty({0}, input.options().dtype(real_dtype));
values = at::linalg_eigvalsh_outf(input, uplo, values);
return values;
}

Tensor& linalg_eigvalsh_out(const Tensor& input, c10::string_view uplo, Tensor& result) {
ScalarType real_dtype = toValueType(input.scalar_type());
checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype);

squareCheckInputs(input);
checkUplo(uplo);

auto expected_result_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
bool result_equal_expected_shape = result.sizes().equals(expected_result_shape);
bool expected_result_type = (result.scalar_type() == real_dtype);
bool copy_needed = !expected_result_type;
copy_needed |= (result.numel() != 0 && !result_equal_expected_shape);
copy_needed |= (result.numel() != 0 && !result.is_contiguous());

Tensor vectors = at::empty({0}, input.options());
Tensor infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));

std::tie(values, vectors) = linalg_eigh_out_info(input, values, vectors, infos, false, uplo);
if (copy_needed) { // we have to allocate a temporary tensor
Tensor result_tmp = at::empty({expected_result_shape}, input.options().dtype(real_dtype));
linalg_eigh_out_info(input, result_tmp, vectors, infos, /*compute_eigenvectors=*/false, uplo);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
} else {
// else use the provided output storage directly
linalg_eigh_out_info(input, result, vectors, infos, /*compute_eigenvectors=*/false, uplo);
}

if (input.dim() > 2) {
batchCheckErrors(infos, "torch.linalg.eigvalsh");
} else {
singleCheckErrors(infos.item().toInt(), "torch.linalg.eigvalsh");
}

return values;
}

// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_eigvalsh_out(const Tensor& input, c10::string_view uplo, Tensor& result) {
checkSameDevice("torch.linalg.eigvalsh", result, input);
ScalarType real_dtype = toValueType(input.scalar_type());
checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype);

Tensor result_tmp = at::linalg_eigvalsh(input, uplo);

at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);

return result;
}

6 changes: 3 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -191,9 +191,9 @@ using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const
DECLARE_DISPATCH(ormqr_fn, ormqr_stub);

using linalg_eigh_fn = void (*)(
Tensor& /*eigenvalues*/,
Tensor& /*eigenvectors*/,
Tensor& /*infos*/,
const Tensor& /*eigenvalues*/,
const Tensor& /*eigenvectors*/,
const Tensor& /*infos*/,
bool /*upper*/,
bool /*compute_eigenvectors*/);
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -293,7 +293,7 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
This function doesn't do any error checks and it's assumed that every argument is valid.
*/
template <typename scalar_t>
void apply_lapack_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
@@ -365,7 +365,7 @@ void apply_lapack_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool uppe
}

// This is a type dispatching helper function for 'apply_lapack_eigh'
void linalg_eigh_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
// This function calculates the symmetric/hermitian eigendecomposition
// in-place tensors should be in batched column major memory format the
// content of eigenvalues, eigenvectors and infos is overwritten by
13 changes: 7 additions & 6 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2268,7 +2268,7 @@ std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::stri
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
static void apply_magma_eigh(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#ifndef USE_MAGMA
TORCH_CHECK(
false,
@@ -2381,23 +2381,24 @@ std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvec

// This is a type dispatch function for 'apply_magma_eigh'
// For small inputs result is computed on CPU
void linalg_eigh_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
// MAGMA just calls LAPACK for eigenvectors.size(-1) <= 128
// See https://bitbucket.org/icl/magma/src/e6fdca447bd402693e8b0b950a898b6879bbcc41/src/zheevd_gpu.cpp?at=master#lines-258
// in addition lda is ignored breaking 0x0 inputs
if (eigenvectors.size(-1) > 128) {
// MAGMA requires eigenvalues and infos tensors to reside on CPU
Tensor eigenvalues_cpu = eigenvalues.to(kCPU);
infos = infos.to(kCPU);
Tensor infos_cpu = infos.to(kCPU);

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
eigenvectors.scalar_type(), "linalg_eigh_cpu", [&] {
eigenvectors.scalar_type(), "linalg_eigh_magma", [&] {
apply_magma_eigh<scalar_t>(
eigenvalues_cpu, eigenvectors, infos, upper, compute_eigenvectors);
eigenvalues_cpu, eigenvectors, infos_cpu, upper, compute_eigenvectors);
});

// Transfer computed by MAGMA results from CPU to GPU
eigenvalues.copy_(eigenvalues_cpu);
infos.copy_(infos_cpu);
} else { // eigenvectors.size(-1) <= 128
// transfer to CPU, compute the result and copy back to GPU
// this is faster than going through MAGMA that does the same
@@ -2413,7 +2414,7 @@ void linalg_eigh_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos,
}
}

void linalg_eigh_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if defined(USE_CUSOLVER)
linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
#else
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
@@ -1027,7 +1027,7 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) {
}

template <typename scalar_t>
static void apply_syevd(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;

cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
@@ -1114,7 +1114,7 @@ static void apply_syevd(Tensor& values, Tensor& vectors, Tensor& infos, bool upp
}

template <typename scalar_t>
static void apply_syevj(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void apply_syevj(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;

cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
@@ -1171,7 +1171,7 @@ static void apply_syevj(Tensor& values, Tensor& vectors, Tensor& infos, bool upp
}

template <typename scalar_t>
static void apply_syevj_batched(Tensor& values, Tensor& vectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;

cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
@@ -1230,19 +1230,19 @@ static void apply_syevj_batched(Tensor& values, Tensor& vectors, Tensor& infos,
TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params));
}

static void linalg_eigh_cusolver_syevd(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void linalg_eigh_cusolver_syevd(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
apply_syevd<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
});
}

static void linalg_eigh_cusolver_syevj(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
static void linalg_eigh_cusolver_syevj(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
apply_syevj<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
});
}

void linalg_eigh_cusolver(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors) {
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
// TODO: syevj_batched should be added here, but at least for CUDA 11.2 it contains a bug leading to incorrect results
// See https://github.com/pytorch/pytorch/pull/53040#issuecomment-793626268 and https://github.com/cupy/cupy/issues/4847

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
@@ -48,7 +48,7 @@ void geqrf_cusolver(const Tensor& input, const Tensor& tau);
void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose);
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);

void linalg_eigh_cusolver(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, bool upper, bool compute_eigenvectors);
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots);

void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
10 changes: 6 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -6680,12 +6680,12 @@

- func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CompositeExplicitAutograd: lu_solve_out
CPU, CUDA: lu_solve_out

- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
variants: method, function
dispatch:
CompositeExplicitAutograd: lu_solve
CPU, CUDA: lu_solve

- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
variants: function
@@ -10094,19 +10094,21 @@
python_module: linalg
variants: function
dispatch:
CompositeExplicitAutograd: linalg_eigh
CPU, CUDA: linalg_eigh

- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
python_module: linalg
dispatch:
CompositeExplicitAutograd: linalg_eigh_out
CPU, CUDA: linalg_eigh_out

- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
python_module: linalg
variants: function

- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_eigvalsh_out

- func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
python_module: linalg
7 changes: 3 additions & 4 deletions test/test_linalg.py
@@ -6430,13 +6430,12 @@ def test_solve_methods_arg_device(self, device):
# b and A have to be modified to match accepted inputs sizes for lu_solve
b = b.unsqueeze(0)
A = A.unsqueeze(0)
with self.assertRaisesRegex(RuntimeError, generic_backend_dispatch_err_str):
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int())

# This checks if a suitable error message is thrown
# when LU output and pivots are on the same device
with self.assertRaisesRegex(RuntimeError,
"Expected LU_pivots and LU_data to be on the same device"):
# when LU output and pivots are not on the same device
with self.assertRaisesRegex(RuntimeError, specific_backend_dispatch_err_str):
torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int())

@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
22 changes: 21 additions & 1 deletion test/test_ops.py
@@ -12,7 +12,7 @@
from torch.testing._internal.common_methods_invocations import \
(op_db, _NOTHING, UnaryUfuncInfo, SpectralFuncInfo)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
(deviceCountAtLeast, instantiate_device_type_tests, ops, onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation
@@ -171,6 +171,26 @@ def unsupported(dtype):

self.assertEqual(supported_backward_dtypes, claimed_backward_supported, msg=msg)

# Validates that each OpInfo works correctly on different CUDA devices
@skipCUDAIfRocm
@onlyCUDA
@deviceCountAtLeast(2)
@ops(op_db, allowed_dtypes=(torch.float32, torch.long))
def test_multiple_devices(self, devices, dtype, op):
for cuda_device_str in devices:
cuda_device = torch.device(cuda_device_str)
# NOTE: only tests on first sample
samples = op.sample_inputs(cuda_device, dtype)
sample = samples[0]
result = op(sample.input, *sample.args, **sample.kwargs)

if isinstance(result, torch.Tensor):
self.assertTrue(result.device == cuda_device)
elif is_iterable_of_tensors(result):
self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
else:
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")

# Tests that the function and its (ndarray-accepting) reference produce the same
# values on the tensors from sample_inputs func for the corresponding op.
@onlyOnCPUAndCUDA
10 changes: 5 additions & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -2086,11 +2086,11 @@ def large_1d_unique(dtype, device):
samples.append(SampleInput(scalar))
samples.append(SampleInput(scalar, args=(0,)))
samples.append(SampleInput(scalar, args=(0, True)))
# no CUDA support for stable sort yet
if not device.startswith('cuda'):
samples.append(SampleInput(scalar, kwargs=dict(stable=True)))
samples.append(SampleInput(scalar, kwargs=dict(dim=0, stable=True)))
samples.append(SampleInput(scalar, kwargs=dict(dim=0, descending=True, stable=True)))

# Test cases for stable sort
samples.append(SampleInput(scalar, kwargs=dict(stable=True)))
samples.append(SampleInput(scalar, kwargs=dict(dim=0, stable=True)))
samples.append(SampleInput(scalar, kwargs=dict(dim=0, descending=True, stable=True)))
return samples

def sample_inputs_index_fill(op_info, device, dtype, requires_grad, **kwargs):
