From 261ceadc79cba084e806a11d48b81c4e2dc30bc2 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 30 Jan 2019 16:01:51 -0800 Subject: [PATCH] =?UTF-8?q?Fix=20uninitialized=20data=20and=20broken=20bro?= =?UTF-8?q?adcasting=20with=20sparse.mm=20and=20spa=E2=80=A6=20(#16572)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …rse.addmm. Fixes https://github.com/pytorch/pytorch/issues/16543. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16572 Differential Revision: D13884235 Pulled By: gchanan fbshipit-source-id: 308916051364d72f72ec56f0495c6c7c09845131 --- aten/src/ATen/native/sparse/SparseTensorMath.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index d8f4d6cb2..f67bb95df 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -593,14 +593,16 @@ Tensor _sparse_addmm( Scalar beta, Scalar alpha ) { - return at::s_native_addmm(t, sparse, dense, beta, alpha); + Tensor b_t; + std::tie(b_t) = expand_size(t, {sparse.size(0), dense.size(1)}, "addmm"); + return at::s_native_addmm(b_t, sparse, dense, beta, alpha); } Tensor _sparse_mm( const SparseTensor& sparse, const Tensor& dense ) { - Tensor t = at::empty({sparse.size(0), dense.size(1)}, dense.options()); + Tensor t = at::zeros({}, dense.options()); return at::_sparse_addmm(t, sparse, dense, 0, 1); }