Skip to content

Commit

Permalink
Add support for aten::squeeze without a dim (pytorch#56)
Browse files Browse the repository at this point in the history
# Description

Adds converter support for aten::squeeze(Tensor) which will remove any static dimension of size 1. An existing converter supports aten::squeeze(Tensor, int dim).

Fixes # (issue)

## Type of change

Please delete options that are not relevant and/or add your own.

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

# Checklist:

- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified
  • Loading branch information
mfeliz-cruise committed Oct 6, 2022
1 parent e608bc7 commit bc7ee08
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 21 deletions.
64 changes: 43 additions & 21 deletions core/conversion/converters/impl/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,57 @@ namespace converters {
namespace impl {
namespace {

// Converter registrations for the two aten::squeeze overloads.
// NOTE(review): this span was a garbled unified diff (old and new lines
// interleaved); reconstructed here as the coherent post-commit version.
auto squeeze_registrations TORCHTRT_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern(
            {"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
             // Squeeze a single, explicitly requested dimension.
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto dim = args[1].unwrapToInt();

               auto selfDim = util::toVec(self->getDimensions());
               // Normalize negative dims (e.g. -1 == last dim).
               if (dim < 0) {
                 dim = selfDim.size() + dim;
               }

               // aten::squeeze is a no-op when the requested dim is not of
               // size 1 — pass the input tensor through unchanged.
               if (selfDim[dim] != 1) {
                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);

                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());

                 return true;
               }

               // Size-1 dim present: realize the squeeze as a TensorRT
               // shuffle (reshape) layer.
               auto shuffle_layer = ctx->net->addShuffle(*self);
               TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
               shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));

               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));

               LOG_DEBUG("Output tensor shape: " << out->getDimensions());

               return true;
             }})
        .pattern(
            {"aten::squeeze(Tensor(a) self) -> (Tensor(a))",
             // Dim-less overload: remove every static dimension of size 1.
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto self_dims = self->getDimensions();
               auto out = self;
               auto squeeze_dims = util::squeezeAllDims(self_dims);
               // Only add a shuffle layer when there is actually something
               // to squeeze; otherwise forward the input tensor directly.
               if (squeeze_dims != self_dims) {
                 auto shuffle_layer = ctx->net->addShuffle(*self);
                 TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
                 shuffle_layer->setReshapeDimensions(squeeze_dims);
                 out = shuffle_layer->getOutput(0);
               }

               auto trt_out = ctx->AssociateValueAndTensor(n->outputs()[0], out);

               LOG_DEBUG("Output tensor shape: " << trt_out->getDimensions());

               return true;
             }});

} // namespace
} // namespace impl
Expand Down
13 changes: 13 additions & 0 deletions core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,19 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
return dims;
}

// Remove every dimension of size 1 from `d`, preserving the relative order of
// the remaining dimensions. When `use_zeros_for_unknown_dims` is set, dynamic
// extents (-1) are emitted as 0 (TensorRT's reshape placeholder meaning
// "copy from input"); otherwise they are kept as -1. If all dims are 1 the
// result has nbDims == 0.
nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims) {
  nvinfer1::Dims out;
  out.nbDims = 0;
  for (int idx = 0; idx < d.nbDims; idx++) {
    const auto extent = d.d[idx];
    if (extent == 1) {
      continue; // drop static size-1 dimensions
    }
    out.d[out.nbDims++] = (use_zeros_for_unknown_dims && extent == -1) ? 0 : extent;
  }
  return out;
}

std::vector<int64_t> toVec(nvinfer1::Dims d) {
std::vector<int64_t> dims;
for (int i = 0; i < d.nbDims; i++) {
Expand Down
1 change: 1 addition & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val = 1, bool use_zeros = true);
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true);
nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims = true);
nvinfer1::Dims toDims(c10::IntArrayRef l);
nvinfer1::Dims toDims(c10::List<int64_t> l);
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);
Expand Down
26 changes: 26 additions & 0 deletions tests/core/conversion/converters/test_squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,29 @@ TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

// Verifies that the dim-less aten::squeeze converter matches TorchScript
// across shapes with size-1 dims in various positions, no size-1 dims at all,
// and fully squeezable inputs (down to a 0-dim tensor).
TEST(Converters, ATenSqueezeNoDimConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : Tensor = aten::squeeze(%0)
      return (%1))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Run the graph through both TorchScript and TensorRT and compare outputs.
  auto check_squeeze = [&g](const at::Tensor& input) {
    auto jit_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, jit_params, {input});

    auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {input});
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
  };

  check_squeeze(at::randint(1, 10, {2, 1, 3, 3}, {at::kCUDA}));
  check_squeeze(at::randint(1, 10, {1, 1, 1, 3}, {at::kCUDA}));
  check_squeeze(at::randint(1, 10, {1, 10, 1, 3}, {at::kCUDA}));
  check_squeeze(at::randint(1, 10, {2, 10, 3, 3}, {at::kCUDA}));
  check_squeeze(at::randint(1, 10, {1, 1}, {at::kCUDA}));
  check_squeeze(at::randint(1, 10, {1}, {at::kCUDA}));
}

0 comments on commit bc7ee08

Please sign in to comment.