diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 354f17a734..518c6a2ded 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -57,6 +57,7 @@ cc_library(
         "impl/batch_norm.cpp",
         "impl/bitwise.cpp",
         "impl/cast.cpp",
+        "impl/chunk.cpp",
         "impl/concat.cpp",
         "impl/constant.cpp",
         "impl/constant_pad.cpp",
diff --git a/core/conversion/converters/CMakeLists.txt b/core/conversion/converters/CMakeLists.txt
index c90a81a6cc..392a82b744 100644
--- a/core/conversion/converters/CMakeLists.txt
+++ b/core/conversion/converters/CMakeLists.txt
@@ -7,6 +7,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/activation.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/batch_norm.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/cast.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/impl/chunk.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/concat.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/constant.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/impl/constant_pad.cpp"
diff --git a/core/conversion/converters/impl/chunk.cpp b/core/conversion/converters/impl/chunk.cpp
new file mode 100644
index 0000000000..a7191133fb
--- /dev/null
+++ b/core/conversion/converters/impl/chunk.cpp
@@ -0,0 +1,87 @@
+#include "core/conversion/converters/converters.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+// clang-format off
+// Converts aten::chunk into one TensorRT slice layer per chunk along `dim`,
+// matching PyTorch semantics: each chunk spans ceil(maxDim / chunks) elements
+// and the last chunk may be narrower (or fewer chunks are produced) when the
+// dimension size does not divide evenly.
+auto chunk_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
+    .pattern({"aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]",
+      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+        auto in = args[0].ITensorOrFreeze(ctx);
+        auto chunks = args[1].unwrapToInt();
+        auto dim = args[2].unwrapToInt();
+        bool dynamic_shape = ctx->input_is_dynamic;
+        int nbdims = in->getDimensions().nbDims;
+
+        // Normalize a negative chunk dimension BEFORE reading d[dim];
+        // otherwise d[dim] is an out-of-bounds access.
+        if (dim < 0) {
+          dim = nbdims + dim;
+        }
+        int maxDim = static_cast<int>(in->getDimensions().d[dim]);
+
+        c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+        c10::TypePtr elementType = lt->getElementType();
+
+        if (dynamic_shape) {
+          TORCHTRT_ASSERT(in->getDimensions().d[dim] != -1, "Can't chunk on dynamic shape dimension!");
+        }
+        if (chunks > maxDim) {
+          LOG_WARNING(
+              "Number of chunks " << chunks << " along dimension " << dim << " is greater than the dimension size "
+                                  << maxDim << "; only " << maxDim << " chunks will be produced");
+        }
+        // PyTorch rounding: chunk width is ceil(maxDim / chunks) and only
+        // ceil(maxDim / step) chunks are actually produced (avoids emitting
+        // zero/negative-size slices when chunks > maxDim).
+        int step = (maxDim + chunks - 1) / chunks;
+        int actual_chunks = (maxDim + step - 1) / step;
+
+        nvinfer1::Dims start_, size_, stride_;
+        start_.nbDims = nbdims;
+        size_.nbDims = nbdims;
+        stride_.nbDims = nbdims;
+        for (int i = 0; i < nbdims; i++) {
+          start_.d[i] = 0;
+          size_.d[i] = in->getDimensions().d[i]; // full extent everywhere except along `dim`
+          stride_.d[i] = 1;
+        }
+
+        // One slice layer per chunk; each slice's output ITensor is wrapped in
+        // a TensorContainer so it can travel through a TorchScript Tensor[].
+        auto list = c10::impl::GenericList(elementType);
+        list.reserve(actual_chunks);
+        if (!dynamic_shape) {
+          int offset = 0;
+          for (int chunk = 0; chunk < actual_chunks; chunk++) {
+            start_.d[dim] = offset;
+            size_.d[dim] = std::min(step, maxDim - offset); // last chunk may be narrower
+            LOG_DEBUG("start_:" << start_);
+            LOG_DEBUG("size_:" << size_);
+            LOG_DEBUG("stride_:" << stride_);
+            auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
+            auto tensor_holder = TensorContainer();
+            tensor_holder.hold_tensor(slice_layer->getOutput(0));
+            auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+            list.emplace_back(ival);
+            offset += step;
+          }
+        }
+        auto split_output_ivalue = std::move(torch::jit::IValue(list));
+        ctx->AssociateValueAndIValue(n->outputs()[0], split_output_ivalue);
+        return true;
+      }});
+// clang-format on
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD
index a8c57b1b41..1973c112fd 100644
--- a/tests/core/conversion/converters/BUILD
+++ b/tests/core/conversion/converters/BUILD
@@ -35,6 +35,10 @@ converter_test(
     name = "test_cast",
 )
 
+converter_test(
+    name = "test_chunk",
+)
+
 converter_test(
     name = "test_clone",
 )
@@ -208,6 +212,7 @@ test_suite(
         ":test_batch_norm",
         ":test_bitwise",
         ":test_cast",
+        ":test_chunk",
         ":test_clamp",
         ":test_clone",
         ":test_comparators",
diff --git a/tests/core/conversion/converters/test_chunk.cpp b/tests/core/conversion/converters/test_chunk.cpp
new file mode 100644
index 0000000000..eccefbea99
--- /dev/null
+++ b/tests/core/conversion/converters/test_chunk.cpp
@@ -0,0 +1,37 @@
+#include <string>
+#include <torch/torch.h>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenChunkConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %2 : int = prim::Constant[value=6]()
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor[] = aten::chunk(%0, %2, %3)
+        %5 : Tensor, %6 : Tensor, %7 : Tensor, %8 : Tensor, %9 : Tensor, %10 : Tensor = prim::ListUnpack(%4)
+        return (%5, %6, %7, %8, %9, %10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {12}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  // Validate every chunk produced by the engine, not just the first one.
+  ASSERT_EQ(jit_results.size(), trt_results.size());
+  for (size_t i = 0; i < jit_results.size(); ++i) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}