diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 7fe0032d80..f86e3c5cb5 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -190,6 +190,7 @@ def slice_scatter_decomposition( step: Optional[int] = None, ) -> torch.Tensor: dim_size = input_tensor.shape[dim] + device_input_tensor = input_tensor.device start = get_positive_dim(start, input_tensor.shape[dim]) if end is None: end = dim_size @@ -216,7 +217,8 @@ def slice_scatter_decomposition( index_tensor_shape.append(src_each_dim) for index in range(start, end, step): cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) - index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device) + index_tensor = torch.stack(cat_tensors, dim) + index_tensor = index_tensor.to(device_input_tensor) index_tensor_64 = index_tensor.to(torch.int64) output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) return output_tensor