
Fix PyTorch batch_matmul conversion when given (3-dim, 2-dim) input pair #7843

Closed

Conversation

Contributor

@haojin2 haojin2 commented Apr 13, 2021

This PR fixes a small bug in the PyTorch converter.
Bug reproduction script:

import torch
from torch import nn
from torch.nn import Linear

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc = Linear(input_size, output_size)

    def forward(self, x):
        return self.fc(x)

batch_size = 128
dim = 64
T = 50

x = torch.randn((batch_size, T, dim))

model = SimpleModel(dim, 1)

model.eval()

# Trace the model so it can be imported by the TVM PyTorch frontend
scripted_model = torch.jit.trace(model, x).eval()

import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Import the traced model into Relay
mod, params = relay.frontend.from_pytorch(scripted_model, [("data", [batch_size, T, dim])])

# Build for CUDA, offloading GEMMs to cuBLAS
target = tvm.target.Target('cuda -libs=cublas')
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)
tvm_ctx = tvm.gpu(0)
rt = graph_executor.GraphModule(lib['default'](tvm_ctx))

ndarray_inputs = {
    "data": x.numpy()
}

rt.set_input(**ndarray_inputs)
rt.run()
print(rt.get_output(0).asnumpy())

Without this fix (current main):

Cannot find config for target=cuda -keys=cuda,gpu -libs=cublas -max_num_threads=1024 -thread_warp_size=32, workload=('batch_matmul_cublas.cuda', ('TENSOR', (128, 50, 64), 'float32'), ('TENSOR', (1, 1, 64), 'float32'), (128, 50, 1)). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):
  File "repro.py", line 41, in <module>
    rt.run()
  File "/home/ubuntu/.local/lib/python3.6/site-packages/tvm-0.8.dev846+g81afb14c4-py3.6-linux-x86_64.egg/tvm/contrib/graph_executor.py", line 206, in run
    self._run()
  File "/home/ubuntu/.local/lib/python3.6/site-packages/tvm-0.8.dev846+g81afb14c4-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: tvm::runtime::GraphExecutor::Run()
  1: std::_Function_handler<void (), tvm::runtime::GraphExecutor::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&, unsigned long)::{lambda()#3}>::_M_invoke(std::_Any_data const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  2: TVMFuncCall
  1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::contrib::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  0: void tvm::contrib::CallBatchGemm<tvm::contrib::CublasSgemmBatchOp>(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, tvm::contrib::CublasSgemmBatchOp)
  File "/home/ubuntu/tvm/src/runtime/contrib/cublas/../cblas/gemm_common.h", line 189
  File "/home/ubuntu/tvm/src/runtime/library_module.cc", line 78
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: ret == 0 (-1 vs. 0) : TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: BatchCount3D(B) == batch_size (1 vs. 128) : 
terminate called after throwing an instance of 'tvm::runtime::InternalError'
  what():  [23:18:59] /home/ubuntu/tvm/src/runtime/workspace_pool.cc:118: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: allocated_.size() == 1 (2 vs. 1) : 
Stack trace:
  0: tvm::runtime::WorkspacePool::~WorkspacePool()
  1: __call_tls_dtors
  2: 0x00007fdd60f44236
  3: exit
  4: __libc_start_main
  5: _start
  6: 0xffffffffffffffff


Aborted (core dumped)

With this fix:

Cannot find config for target=cuda -keys=cuda,gpu -libs=cublas -max_num_threads=1024 -thread_warp_size=32, workload=('batch_matmul_cublas.cuda', ('TENSOR', (128, 50, 64), 'float32'), ('TENSOR', (128, 1, 64), 'float32'), (128, 50, 1)). A fallback configuration is used, which may bring great performance regression.
[[[-0.21468109]
  [ 0.3858583 ]
  [ 0.16572809]
  ...
  [-0.03322682]
  [ 0.33868816]
  [ 0.3021463 ]]

 [[ 1.052577  ]
  [ 0.26492748]
  [ 0.37078723]
  ...
  [-0.0752994 ]
  [-0.66205776]
  [-0.19348428]]

 [[ 0.6743065 ]
  [ 0.02969196]
  [-0.03708391]
  ...
  [ 0.16056934]
  [ 0.41362724]
  [ 0.629748  ]]

 ...

 [[-0.05230951]
  [-0.3116043 ]
  [-0.07618818]
  ...
  [-0.7429178 ]
  [ 0.34146884]
  [-0.46452078]]

 [[ 0.6838716 ]
  [-0.0820943 ]
  [ 0.01337433]
  ...
  [ 0.6866671 ]
  [-0.4317361 ]
  [ 0.16978306]]

 [[ 0.7288995 ]
  [ 0.57882047]
  [ 0.40440276]
  ...
  [ 0.36602104]
  [ 0.6143365 ]
  [ 0.5057366 ]]]
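
For reference, here is a minimal numpy sketch of the shapes involved (an illustration only, not the converter code): without the fix the second batch_matmul operand is the (1, 1, 64) weight, whose batch count of 1 the cuBLAS batch-gemm kernel rejects against the activations' batch of 128; with the fix the weight is effectively repeated across the batch to (128, 1, 64).

import numpy as np

batch_size, T, dim = 128, 50, 64
x = np.random.randn(batch_size, T, dim).astype("float32")
w = np.random.randn(1, dim).astype("float32")  # Linear(64, 1) weight

# Weight repeated across the batch, matching the (128, 1, 64) workload above
w_batched = np.broadcast_to(w[None, :, :], (batch_size, 1, dim))

# batch_matmul semantics: out[b, i, j] = sum_k A[b, i, k] * B[b, j, k]
out = np.einsum("bik,bjk->bij", x, w_batched)
print(out.shape)  # (128, 50, 1), the expected Linear output shape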

Contributor

@jcf94 jcf94 left a comment

@haojin2 Thanks! Nice catch!

Your solution for the [2-dim matrix * 3-dim matrix] case seems to repeat the [2-dim matrix] batch_size times and then apply [batch_matmul]. I'm wondering whether it would be better to merge the first two dims of the [3-dim matrix] and perform a simple [matmul].
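
To make this concrete, a rough Relay sketch of the reshape-then-matmul idea (illustration only, with shapes taken from the repro script above; not code from either PR):

import tvm
from tvm import relay

batch_size, T, dim, out_dim = 128, 50, 64, 1

x = relay.var("x", shape=(batch_size, T, dim), dtype="float32")
w = relay.var("w", shape=(out_dim, dim), dtype="float32")

# Merge the first two dims of the 3-dim operand, run a plain dense/matmul,
# then restore the original leading dims.
x2d = relay.reshape(x, (batch_size * T, dim))     # (6400, 64)
y2d = relay.nn.dense(x2d, w)                      # (6400, 1)
y = relay.reshape(y2d, (batch_size, T, out_dim))  # (128, 50, 1)

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
print(relay.transform.InferType()(mod))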

Contributor Author

haojin2 commented Apr 14, 2021

@jcf94 I think what you say makes sense; I'll make that change.

@comaniac
Contributor

Is this related to #7730? The current CuBLAS support for batch_matmul doesn't support implicit broadcasting, but the TE compute does. It would be better to support it on the CuBLAS side without introducing a new op.
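
As a small illustration of the implicit broadcasting in question (a numpy analogy, not TVM's actual TE compute): a batch dim of size 1 can broadcast against the full batch, which is exactly the (1, 1, 64) operand in the workload above, while the cuBLAS batch-gemm path requires equal batch counts.

import numpy as np

A = np.random.randn(128, 50, 64).astype("float32")  # activations
B = np.random.randn(1, 1, 64).astype("float32")     # weight with batch dim 1

# The size-1 batch dim broadcasts implicitly to 128 here ...
out = np.matmul(A, np.transpose(B, (0, 2, 1)))      # (128, 50, 64) @ (1, 64, 1)
print(out.shape)                                    # (128, 50, 1)

# ... whereas the cuBLAS batch-gemm kernel checks that both operands have the
# same batch count, hence "BatchCount3D(B) == batch_size (1 vs. 128)" above.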

Contributor

jcf94 commented Apr 14, 2021

> Is this related to #7730? The current CuBLAS support for batch_matmul doesn't support implicit broadcasting, but the TE compute does. It would be better to support it on the CuBLAS side without introducing a new op.

I guess this is a different problem, just a bug in the PyTorch frontend.

Contributor

jcf94 commented Apr 14, 2021

By the way, it seems there is another PR, #7845, related to this problem.

Contributor

wweic commented Apr 14, 2021

Thanks @jcf94 @comaniac for the prompt review.

@haojin2 Looks like #7845 implements the idea @jcf94 suggested; should we help review #7845 and merge that instead?

Member

masahi commented Apr 15, 2021

Given that #7845 has been merged, I'll close this. Thanks @haojin2

@masahi masahi closed this Apr 15, 2021