[MPS] Fixes SiLU on non-contiguous tensors (pytorch#139006)
Similar to pytorch#123049: on macOS versions prior to 15.0, `SiLU` also produces random values, `0.0`, or `NaN` when the input tensor is not contiguous.
Originally, the problem was found in jy0205/Pyramid-Flow#113.
Pull Request resolved: pytorch#139006
Approved by: https://github.com/malfet
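
For context, a minimal repro sketch (not part of this PR, and assuming a machine with an MPS device running an affected macOS release before 15.0) that mirrors the new test below: it builds a non-contiguous view via transpose and compares the MPS result against the CPU reference.

import torch

cpu_x = torch.randn(4, 6, device='cpu', dtype=torch.float)
mps_x = cpu_x.to('mps')

# Transposing creates a non-contiguous view without copying data.
cpu_t = cpu_x.transpose(0, 1)
mps_t = mps_x.transpose(0, 1)
assert not mps_t.is_contiguous()

cpu_out = torch.nn.functional.silu(cpu_t)
mps_out = torch.nn.functional.silu(mps_t)

# Before this fix, the MPS result could contain random values, 0.0, or NaN.
print(torch.allclose(cpu_out, mps_out.cpu(), atol=1e-5))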
niw authored and rahulsingh-intel committed Nov 5, 2024
1 parent e2495b3 commit 3664d0d
Showing 2 changed files with 25 additions and 6 deletions.
13 changes: 11 additions & 2 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -1653,6 +1653,11 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
 
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "silu_out_mps:" + getTensorsStringKey({self});
 
@@ -1673,12 +1678,16 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    result.copy_(result_);
+  }
 }
 
 TORCH_IMPL_FUNC(silu_backward_out_mps)
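
The change above follows the same pattern as pytorch#123049: when the input is not contiguous in any supported memory format, the input Placeholder is built with a gather op, the graph writes into a contiguous scratch tensor result_, and the scratch is copied back into result at the end. For users stuck on an affected macOS release without this fix, a possible user-level workaround (a sketch, not part of this PR) is to hand SiLU a contiguous copy of the input:

import torch

def silu_contiguous_workaround(x: torch.Tensor) -> torch.Tensor:
    # Assumption: making the input contiguous up front avoids the bad results;
    # .contiguous() is a no-op when x is already contiguous.
    return torch.nn.functional.silu(x.contiguous())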
18 changes: 14 additions & 4 deletions test/test_mps.py
@@ -6792,9 +6792,18 @@ def helper(shape, beta, threshold, dtype):
     # Test silu
 
     def test_silu(self):
-        def helper(shape):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
 
             silu_result = torch.nn.SiLU()(x)
             silu_result_cpu = torch.nn.SiLU()(cpu_x)
@@ -6810,7 +6819,8 @@ def helper(shape):
 
         # Test empty shape too
         for shape in [[], (2, 3), (2, 8, 4, 5)]:
-            helper(shape)
+            for contiguous in [True, False]:
+                helper(shape, contiguous)
 
     def test_cast_mps_to_cpu(self):
         def helper(src_dtype, dst_dtype):
