Skip to content

Commit

Permalink
Merge pull request #1420 from gs-olive/einsum_schema_fix
Browse files Browse the repository at this point in the history
fix: `aten::einsum` schema switch for Torch 1.14 nightly
  • Loading branch information
peri044 authored Oct 28, 2022
2 parents 3053ecc + f6b5bad commit e75123c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace impl {
namespace {

auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
{"aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Extract equation and list of tensors
auto equation = args[0].unwrapToString();
Expand Down
12 changes: 8 additions & 4 deletions tests/core/conversion/converters/test_einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
graph(%x.1 : Tensor, %x.2 : Tensor):
%0 : str = prim::Constant[value="ij,jk->ik"]()
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
%4 : Tensor = aten::einsum(%0, %3)
%none : NoneType = prim::Constant()
%4 : Tensor = aten::einsum(%0, %3, %none)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand All @@ -34,7 +35,8 @@ TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
graph(%x.1 : Tensor, %x.2 : Tensor):
%0 : str = prim::Constant[value="abcd,abcd->abcd"]()
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
%4 : Tensor = aten::einsum(%0, %3)
%none : NoneType = prim::Constant()
%4 : Tensor = aten::einsum(%0, %3, %none)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand All @@ -59,7 +61,8 @@ TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
graph(%x.1 : Tensor):
%0 : str = prim::Constant[value="jk->kj"]()
%3 : Tensor[] = prim::ListConstruct(%x.1)
%4 : Tensor = aten::einsum(%0, %3)
%none : NoneType = prim::Constant()
%4 : Tensor = aten::einsum(%0, %3, %none)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand All @@ -83,7 +86,8 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
graph(%x.1 : Tensor, %x.2 : Tensor):
%0 : str = prim::Constant[value="a,b->ab"]()
%3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
%4 : Tensor = aten::einsum(%0, %3)
%none : NoneType = prim::Constant()
%4 : Tensor = aten::einsum(%0, %3, %none)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand Down

0 comments on commit e75123c

Please sign in to comment.