From f6b5bad6e1a424c62ef8ec12b5c0b15a4d095bb3 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 25 Oct 2022 16:50:24 -0700 Subject: [PATCH] fix: `aten::einsum` schema switch for Torch 1.14 nightly - `aten::einsum` schema in 1.14 is `aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor`, whereas that of 1.13 has only the first two arguments - Updated test cases to use three arguments instead of two - Updated converter schema to allow for additional arguments - Fails 1.13 tests, as 1.14 schema is incompatible --- core/conversion/converters/impl/einsum.cpp | 2 +- tests/core/conversion/converters/test_einsum.cpp | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/conversion/converters/impl/einsum.cpp b/core/conversion/converters/impl/einsum.cpp index fb031f6c38..3503aa35ff 100644 --- a/core/conversion/converters/impl/einsum.cpp +++ b/core/conversion/converters/impl/einsum.cpp @@ -12,7 +12,7 @@ namespace impl { namespace { auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( - {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)", + {"aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { // Extract equation and list of tensors auto equation = args[0].unwrapToString(); diff --git a/tests/core/conversion/converters/test_einsum.cpp b/tests/core/conversion/converters/test_einsum.cpp index ff7ba201ff..d45dea54fb 100644 --- a/tests/core/conversion/converters/test_einsum.cpp +++ b/tests/core/conversion/converters/test_einsum.cpp @@ -9,7 +9,8 @@ TEST(Converters, ATenEinsumConvertsMatMulCorrectly) { graph(%x.1 : Tensor, %x.2 : Tensor): %0 : str = prim::Constant[value="ij,jk->ik"]() %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2) - %4 : Tensor = aten::einsum(%0, %3) + %none : NoneType = prim::Constant() + %4 : Tensor = aten::einsum(%0, %3, %none) return (%4))IR"; auto g = std::make_shared(); @@ -34,7 +35,8 @@ TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) { graph(%x.1 : Tensor, %x.2 : Tensor): %0 : str = prim::Constant[value="abcd,abcd->abcd"]() %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2) - %4 : Tensor = aten::einsum(%0, %3) + %none : NoneType = prim::Constant() + %4 : Tensor = aten::einsum(%0, %3, %none) return (%4))IR"; auto g = std::make_shared(); @@ -59,7 +61,8 @@ TEST(Converters, ATenEinsumConvertsTransposeCorrectly) { graph(%x.1 : Tensor): %0 : str = prim::Constant[value="jk->kj"]() %3 : Tensor[] = prim::ListConstruct(%x.1) - %4 : Tensor = aten::einsum(%0, %3) + %none : NoneType = prim::Constant() + %4 : Tensor = aten::einsum(%0, %3, %none) return (%4))IR"; auto g = std::make_shared(); @@ -83,7 +86,8 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) { graph(%x.1 : Tensor, %x.2 : Tensor): %0 : str = prim::Constant[value="a,b->ab"]() %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2) - %4 : Tensor = aten::einsum(%0, %3) + %none : NoneType = prim::Constant() + %4 : Tensor = aten::einsum(%0, %3, %none) return (%4))IR"; auto g = std::make_shared();