
[BACKEND] initial llvm codegen for amdgpu #402

Merged: 44 commits, Sep 13, 2017

Conversation

aditya4d (Contributor) commented:

The test results:

$ python test_gemm.py
skip because nvptx -mcpu=sm_20 is not enabled..
skip because rocm is not enabled..
skip because metal is not enabled..
skip because opencl is not enabled..
skip because cuda is not enabled..
$
$ python test_codegen_device.py
$

tqchen (Member) commented Aug 31, 2017

Check your runtime, as it reports rocm not enabled. You need to change src/runtime/module.cc to add a rocm enable check.
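A minimal sketch of the kind of change, assuming the RuntimeEnabled helper in src/runtime/module.cc dispatches on the target name and probes the registered device API (names here mirror the surrounding file and are an assumption):

bool RuntimeEnabled(const std::string& target) {
  std::string f_name;
  if (target == "cpu") {
    return true;
  } else if (target == "cuda" || target == "gpu") {
    f_name = "device_api.gpu";
  } else if (target == "rocm") {
    f_name = "device_api.rocm";  // the missing rocm enable check
  } else {
    return false;
  }
  return runtime::Registry::Get(f_name) != nullptr;
}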

// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function
module_->getOrInsertNamedMetadata("nvvm.annotations")
Member:

You likely don't need this; it needs to be changed to the AMD kernel annotations.
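(For context: on AMDGPU, kernels are typically marked with the amdgpu_kernel calling convention rather than NVPTX-style nvvm.annotations module metadata; the AddFunction sketch later in this thread shows this.)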

CHECK_EQ(info.scope.rank, 1)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
Member:

Check whether the address space is consistent with the AMD GPU backend.
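For reference, a sketch of the address-space constants as commonly documented for LLVM's AMDGPU backend; this numbering has shifted across LLVM releases, so treat it as an assumption to verify against the LLVM version in use:

// LLVM AMDGPU address spaces (verify against the LLVM release in use):
//   1 = global, 3 = local (LDS, the "shared" equivalent), 5 = private (scratch)
const unsigned shared_address_space = 3;  // LDS, which happens to match NVPTX's shared = 3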

Contributor (author):

I have to make another pass on the codegen part, as there are obvious differences between nvptx and amdgcn codegen. Is there a way I can see the IR directly?

Member:

If you want to check the LLVM code, call module_->dump(); you have to insert it manually in the code, though. Otherwise, implement GetSource in the HIP module, which should give you the assembly.
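A build-independent alternative for inspecting the IR is a sketch like this; llvm::Module::print is compiled into all LLVM build configurations, unlike dump():

#include <llvm/Support/raw_ostream.h>

// Print the module's textual IR to stderr.
module_->print(llvm::errs(), nullptr);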

Contributor (author):

I'll use LOG(WARNING) << module_->dump(); to see it.

Contributor (author):

Hi @tqchen, I am getting the following error:

$ python tests/python/unittest/test_codegen_device.py                          
Traceback (most recent call last):
  File "tests/python/unittest/test_codegen_device.py", line 1, in <module>
    import tvm
  File "/home/aditya/tvm/python/tvm/__init__.py", line 5, in <module>
    from . import tensor
  File "/home/aditya/tvm/python/tvm/tensor.py", line 4, in <module>
    from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
  File "/home/aditya/tvm/python/tvm/_ffi/node.py", line 8, in <module>
    from .node_generic import NodeGeneric, convert_to_node, const
  File "/home/aditya/tvm/python/tvm/_ffi/node_generic.py", line 7, in <module>
    from .base import string_types
  File "/home/aditya/tvm/python/tvm/_ffi/base.py", line 43, in <module>
    _LIB, _LIB_NAME = _load_lib()
  File "/home/aditya/tvm/python/tvm/_ffi/base.py", line 35, in _load_lib
    lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
  File "/usr/lib/python2.7/ctypes/__init__.py", line 362, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/aditya/tvm/lib/libtvm.so: undefined symbol: _ZNK4llvm6Module4dumpEv

Member:

Usually I use module_->dump() without piping it to a stream, and it should work.
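(For context: the undefined symbol _ZNK4llvm6Module4dumpEv demangles to llvm::Module::dump() const, which LLVM compiles in only for builds with assertions or LLVM_ENABLE_DUMP enabled; prebuilt release packages typically lack it. The module_->print(llvm::errs(), nullptr) form sketched above works in all builds. Note also that dump() returns void, so it cannot be streamed into LOG(WARNING).)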

tqchen changed the title from "[v1] added initial llvm codegen for amdgpu" to "[BACKEND][WIP] added initial llvm codegen for amdgpu" on Aug 31, 2017

aditya4d (Contributor) commented:

@tqchen I am getting the following error:

$ python test_codegen_device.py
[22:30:06] /home/aditya/tvm/dmlc-core/include/dmlc/./logging.h:308: [22:30:06] src/runtime/module.cc:74: Module[hip] does not support GetSource

Stack trace returned 10 entries:
[bt] (0) /home/aditya/tvm/lib/libtvm.so(_ZN3tvm7runtime10ModuleNode9GetSourceERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE+0x30f) [0x7f3ecfc6d16f]
[bt] (1) /home/aditya/tvm/lib/libtvm.so(+0x955fd2) [0x7f3ecfc6ffd2]
[bt] (2) /home/aditya/tvm/lib/libtvm.so(TVMFuncCall+0x5e) [0x7f3ecfc7961e]
[bt] (3) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f3ed4d10e40]
[bt] (4) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7f3ed4d108ab]
[bt] (5) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7f3ed4f203df]
[bt] (6) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(+0x11d82) [0x7f3ed4f24d82]
[bt] (7) python(PyObject_Call+0x43) [0x4b0cb3]
[bt] (8) python(PyEval_EvalFrameEx+0x5faf) [0x4c9faf]
[bt] (9) python(PyEval_EvalCodeEx+0x255) [0x4c2765]

Traceback (most recent call last):
  File "test_codegen_device.py", line 88, in <module>
    test_add_pipeline()
  File "test_codegen_device.py", line 85, in test_add_pipeline
    check_target("rocm", host="llvm")
  File "test_codegen_device.py", line 46, in check_target
    code = mdev.get_source()
  File "/home/aditya/tvm/python/tvm/module.py", line 34, in get_source
    return _GetSource(self, fmt)
  File "/home/aditya/tvm/python/tvm/_ffi/function.py", line 255, in my_api_func
    return flocal(*args)
  File "/home/aditya/tvm/python/tvm/_ffi/_ctypes/function.py", line 183, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/home/aditya/tvm/python/tvm/_ffi/base.py", line 62, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [22:30:06] src/runtime/module.cc:74: Module[hip] does not support GetSource

" -mcpu=gfx900" +
target.substr(5, target.length() - 5));
) >= 4 &&
target.substr(0, 4) == "rocm");
Member:

I would recommend using GetLLVMTargetMachine so further options can be passed here.

Contributor (author):

Fixed!

aditya4d (Contributor) commented:

Summarizing,

  1. It would be great to know a way to dump the IR, to see whether all the intrinsics, device code, and metadata are generated correctly (even the data layout and target triple).
  2. hipModuleLaunchKernel doesn't support the primary method from cudaLaunchKernel; it uses extra_args (see the sketch after this list).
  3. I am seeing a segfault when running
$ python test_codegen_device.py
[23:11:18] src/runtime/rocm/rocm_module.cc:64: HSACO
Bus error (core dumped)

I have to debug it more.
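Since the extra_args path in point 2 keeps coming up, here is a minimal sketch of launching through it, assuming the packed-buffer macros from hip_runtime_api.h (grid/block sizes are placeholders):

#include <hip/hip_runtime_api.h>

// Launch func with all arguments pre-packed into one buffer, using HIP's
// extra-args mechanism instead of the kernelParams array.
void LaunchPacked(hipFunction_t func, void* packed_args, size_t packed_nbytes) {
  void* config[] = {
      HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
      HIP_LAUNCH_PARAM_BUFFER_SIZE,    &packed_nbytes,
      HIP_LAUNCH_PARAM_END};
  hipModuleLaunchKernel(func, /*gridDim*/ 1, 1, 1, /*blockDim*/ 256, 1, 1,
                        /*sharedMemBytes*/ 0, /*stream*/ nullptr,
                        /*kernelParams*/ nullptr, config);
}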

else {
CHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
Member:

You will need this; it corresponds to get_group_id in OpenCL.

Contributor (author):

Fixed it!

CHECK_EQ(info.scope.rank, 1)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
Contributor (author):

Yes, I am aware of it. Once I get the LLVM IR dump, I can get a better understanding of what to change, or even add more functionality.

Member:

I see; it seems this is something we will need frequently. Let us simply also print out the LLVM IR and save it to the (optional) code field in the ROCMModule, so we can access it with module.get_source().
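A minimal sketch of capturing the IR at build time; how it is wired into the ROCMModule constructor is assumed for illustration:

#include <llvm/Support/raw_ostream.h>
#include <string>

std::string ir;
llvm::raw_string_ostream os(ir);
module_->print(os, nullptr);   // serialize the module's textual IR
os.flush();
// Then pass ir along with the hsaco binary when constructing the
// ROCMModule, so GetSource can return it.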

Contributor (author):

Updated the PR with the IR dump code. I see some places that need to be changed (which is why I am getting the bus error).

aditya4d (Contributor) commented:

@tqchen I don't expect the code to run yet, as we have a new feature coming in from HIP (https://github.com/adityaatluri/HIP/commit/8a7328fd9de7f1d174e5f4b75de734fb4032f5b6). I'll write a C++ test to check whether the generated IR is valid or not.
But the purpose of this PR is to get good IR generated.

tqchen (Member) commented Aug 31, 2017

Can you elaborate a bit on what is expected? For example, does the problem lie in the additional argument packing, or in other parts? We might be able to change the runtime accordingly to solve this issue.

It would be nice to get runnable code.

tqchen (Member) commented Aug 31, 2017

Specifically, we can pre-pack the arguments into a single buffer, if necessary, without going through the HIP CUDA-compatible API.

For example, in the Metal runtime everything is packed into an array of ArgUnion, and the device code will receive packed arguments instead: https://github.com/dmlc/tvm/blob/master/src/runtime/metal/metal_module.mm#L202
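A minimal sketch of that packing scheme, with illustrative names rather than the exact pack_args.h or Metal-runtime API:

#include <cstdint>
#include <vector>

// One 32-bit scalar slot, analogous to the ArgUnion used by the Metal runtime.
union ArgUnion {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};

// Pack non-pointer scalar arguments back-to-back into one contiguous buffer;
// pointer (buffer) arguments are passed separately.
std::vector<ArgUnion> PackScalars(const std::vector<int64_t>& int_args,
                                  const std::vector<double>& float_args) {
  std::vector<ArgUnion> packed;
  for (int64_t v : int_args) {
    ArgUnion u;
    u.v_int32 = static_cast<int32_t>(v);
    packed.push_back(u);
  }
  for (double v : float_args) {
    ArgUnion u;
    u.v_float32 = static_cast<float>(v);
    packed.push_back(u);
  }
  return packed;
}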

@@ -113,7 +113,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
ret_void ? t_void_ : t_int_, arg_type, false);
// setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C);
function_->setCallingConv(dev_type == AMDGPU ? llvm::CallingConv::AMDGPU_KERNEL : llvm::CallingConv::C);
Member:

You likely don't have to do this; we can override AddFunction and call setCallingConv after calling AddFunctionInternal.

Contributor (author):

I don't want to create multiple functions that do the same thing, and it is easier to read the code this way. Do you want me to overload the function?

Member:

We don't have to overload AddFunctionInternal. Simply override AddFunction, which calls AddFunctionInternal, and then call function_->setCallingConv again.

void AddFunction(const LoweredFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true, AMDGPU);
// annotate as kernel function
Member:

I.e., simply add function_->setCallingConv here to override the old flag.
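Putting the two suggestions together, a sketch of the override, assuming AddFunctionInternal leaves function_ pointing at the freshly created function as in the diff above:

void CodeGenAMDGPU::AddFunction(const LoweredFunc& f) {
  // Add the function with a void return value, as the base codegen does.
  CodeGenLLVM::AddFunctionInternal(f, true);
  // Then override the default C calling convention to mark an AMDGPU kernel.
  function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
}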

aditya4d (Contributor) commented Sep 6, 2017

@tqchen I am able to see good IR now, and I am able to generate ISA from it (through llc): https://gist.github.com/adityaatluri/1ac1ff72b927e42fdd8a61f98176039a

tqchen (Member) commented Sep 7, 2017

Can you confirm the gap between this and an actual kernel test that runs? Thanks.

aditya4d (Contributor) commented Sep 7, 2017

Can you explain a bit more about what you meant?

tqchen (Member) commented Sep 7, 2017

I mean directly running the test via the ROCm module and verifying the correctness of the kernel.

aditya4d (Contributor) commented Sep 7, 2017

Gotcha. It turns out the IR is not valid. These lines are causing bad output results: https://gist.github.com/adityaatluri/1ac1ff72b927e42fdd8a61f98176039a#file-tvm-amdgcn-ll-L10
Do you know where they are coming from?
Correct IR: https://gist.github.com/adityaatluri/1ac1ff72b927e42fdd8a61f98176039a#file-tvm-amdgcn-correct-ll

tqchen (Member) commented Sep 7, 2017

This is a shift left, used for address calculation, in the condition

if (blockIdx.x * 256 + threadIdx.x < n) { 
    ...
}

blockIdx.x * 256 becomes blockIdx.x << 8, which corresponds to that line.

aditya4d (Contributor) commented Sep 7, 2017

Do you know which code block generates this?

tqchen (Member) commented Sep 7, 2017

It should be due to LLVM's constant folder in IRBuilder, which automatically folds Mul(blockIdx, 256) into a left shift. Is the shift not supported by the AMD ISA?

aditya4d (Contributor) commented Sep 7, 2017

It does support shl/shr, but I don't think we need to mul the workitem id by 8. Also, I didn't see the last arg i32. Let me retest. Also, 1024 is good for AMD GPUs.

aditya4d (Contributor) commented Sep 7, 2017

After retesting, the output data validated correctly.

tqchen (Member) commented Sep 7, 2017

Nice. Can we directly use RocmModule to run the test, instead of the current test that is decoupled from the compiler?

aditya4d (Contributor) commented Sep 7, 2017

We need the new HIP, which the team is working on. Once it lands, it'll make it easier to launch kernels.

@@ -100,7 +101,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
Type t = arg.type();
if (t.is_handle() && f->handle_data_type.count(arg)) {
arg_type.push_back(
LLVMType(f->handle_data_type[arg].type())->getPointerTo());
LLVMType(f->handle_data_type[arg].type())->getPointerTo(isTargetAMD ? 1 : 0));
Member:

Add a virtual function GetGlobalAddressSpace to CodeGenLLVM and override it in CodeGenAMDGPU.
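A sketch of that hook, assuming addrspace(1) is the AMDGPU global address space while 0 stays the default elsewhere:

// In CodeGenLLVM (base class):
virtual unsigned GetGlobalAddressSpace() const { return 0; }

// In CodeGenAMDGPU:
unsigned GetGlobalAddressSpace() const override { return 1; }

// The call site above then becomes:
//   LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace());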

@@ -113,7 +114,8 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
ret_void ? t_void_ : t_int_, arg_type, false);
// setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C);
function_->setCallingConv(isTargetAMD ?
llvm::CallingConv::AMDGPU_KERNEL : llvm::CallingConv::C);
Member:

Do this in the AddFunction override in CodeGenAMDGPU.

CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function
/*
module_->getOrInsertNamedMetadata("nvvm.annotations")
Member:

remove these comments

@@ -0,0 +1,176 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_nvptx.cc
Member:

update the comments

llvm::Value* CreateStorageSync(const Call* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9
Member:

Remove the comment here. Is there any need for a warp (wavefront) synchronizer on AMD GPUs?

Contributor (author):

There are sync commands for AMD GPUs, but if it is CUDA 9 specific, current-generation AMD GPUs don't support it.

};

runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
CHECK(1) << target;
Member:

remove prints

Contributor (author):

I'll do it in the last commit.

tqchen (Member) commented Sep 7, 2017

I see what you mean by looking at the test code you provided. Actually, the metadata is already available in the TVM runtime, and we are using this to pack the data.

So one possible way is to simply implement parameter packing in TVM.

For example, https://github.com/dmlc/tvm/blob/master/src/runtime/pack_args.h#L150 packs non-pointer arguments into a contiguous memory region of one buffer (ArgUnion). This is used by the Metal runtime, which requires non-pointer arguments to be packed as one buffer. If we know the parameter packing requirement (e.g., the alignment of each value), the same approach can work here; see the sketch below.
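A sketch of the size computation under one plausible packing requirement, namely that each value is aligned to its own size (illustrative only, not the pack_args.h API):

#include <cstddef>
#include <vector>

// Total bytes needed to pack values back-to-back, aligning each to its size.
size_t PackedNBytes(const std::vector<size_t>& arg_sizes) {
  size_t offset = 0;
  for (size_t sz : arg_sizes) {
    offset = (offset + sz - 1) / sz * sz;  // round offset up to a multiple of sz
    offset += sz;
  }
  return offset;
}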

CHECK(tm->addPassesToEmitFile(
pass, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";

Member:

remove comments here

Contributor (author):

We may need this in the future; it is especially helpful for kernel debugging.

Member:

I know, but usually it is not good to keep debug code in production. What we should do is add it back later when there is a need to debug.

};

runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
CHECK(1) << target;
Member:

remove this

tqchen (Member) commented Sep 13, 2017

Last two comments, and we can merge this in. Thanks for the work to make this happen!

tqchen changed the title from "[BACKEND][WIP] added initial llvm codegen for amdgpu" to "[BACKEND] initial llvm codegen for amdgpu" on Sep 13, 2017
arr.data = &obj[0];
arr.size = obj.length();

std::string hsaco = (*f)(arr), ll;
Member:

You are using a comma expression here; is it intended? This will result in hsaco taking the value from ll.
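(Strictly speaking, std::string hsaco = (*f)(arr), ll; is a declaration with two declarators rather than a comma expression: hsaco is initialized from (*f)(arr), and ll is a second, default-constructed string. Splitting it into two declarations would make the intent explicit.)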

destAsm.SetUnbuffered();
module->print(dest_ll, nullptr);
std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
Member:

Remove mObjFile and mAsmFile. We can consider holding two optional sources in RocmModule, both the ll and the asm, and returning them when a different source suffix is requested; that might help you in debugging.

std::unique_ptr<llvm::Module> mObjFile = llvm::CloneModule(module.get());
llvm::legacy::PassManager pass;


Member:

remove extra line here

tqchen merged commit 891e226 into apache:master on Sep 13, 2017
masahi (Member) commented Oct 10, 2017

Hi @adityaatluri, do I need a custom llvm + clang from AMD to use the rocm backend?
I tried to run test_gemm.py with TVM + llvm 4.0, but it hangs forever.

My card is an R9 Nano, so I replaced gfx900 in this line with gfx803.

aditya4d (Contributor) commented:

@masahi can you try LLVM 5.0? There are a few issues with the rocm runtime that will be fixed soon.

masahi (Member) commented Oct 10, 2017

@adityaatluri Thanks for the quick response. I'll try llvm 5.0 after I am back from work.

By the way, the opencl backend with rocm's opencl stack works fine on my Nano. I can pass all the tests in https://github.com/dmlc/tvm/tree/master/topi/tests/python.

aditya4d (Contributor) commented:

Thank you for trying it out.

masahi (Member) commented Oct 10, 2017

@adityaatluri I built tvm with llvm 5.0 from the official ubuntu package, but test_gemm.py still hangs my entire system.
It gives a familiar error message, saying 'Memory access fault by GPU node -1 on address ...'.

I think something is wrong with the codegen. With the opencl backend, when I do this:

f_opencl = tvm.build(s, [A, B, C], "opencl")
dev_module_opencl = f_opencl.imported_modules[0]
print(dev_module_opencl.get_source())

I get a valid opencl kernel string. But for the rocm backend,

f_rocm = tvm.build(s, [A, B, C], "rocm")
dev_module_rocm = f_rocm.imported_modules[0]
print(repr(dev_module_rocm.get_source()))

just prints out '\x7fELF\x02\x01\x01@'

test_codegen_device.py fails for the same reason. But test_runtime_ndarray.py passes.

Any ideas?


masahi (Member) commented Oct 10, 2017

Ok, I get this:
https://gist.github.com/masahi/2c81b07aaf2f2e58cd0a053fe3d1fb02#file-myadd_kernel0-ll
Does it look good?

I'm correctly linking against llvm 5.0.
$ ldd libtvm.so
linux-vdso.so.1 => (0x00007ffd97336000)
libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fb7d954c000)
libLLVM-5.0.so.1 => /usr/lib/x86_64-linux-gnu/libLLVM-5.0.so.1 (0x00007fb7d5f92000)
libhip_hcc.so => /opt/rocm/lib/libhip_hcc.so (0x00007fb7d5ced000)
libOpenCL.so.1 => /usr/lib/x86_64-linux-gnu/libOpenCL.so.1 (0x00007fb7d5ae2000)
libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fb7d575f000)
libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fb7d5456000)
libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fb7d5240000)
libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fb7d5022000)
libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fb7d4c58000)
/lib64/ld-linux-x86-64.so.2 (0x000055f5d1d3a000)
libffi.so.6 => /usr/lib/x86_64-linux-gnu/libffi.so.6 (0x00007fb7d4a50000)
libedit.so.2 => /usr/lib/x86_64-linux-gnu/libedit.so.2 (0x00007fb7d4817000)
libtinfo.so.5 => /lib/x86_64-linux-gnu/libtinfo.so.5 (0x00007fb7d45ee000)
libz.so.1 => /lib/x86_64-linux-gnu/libz.so.1 (0x00007fb7d43d4000)
libunwind.so.8 => /usr/lib/x86_64-linux-gnu/libunwind.so.8 (0x00007fb7d41b8000)
libhc_am.so => /opt/rocm/lib/libhc_am.so (0x00007fb7d3f96000)
libbsd.so.0 => /lib/x86_64-linux-gnu/libbsd.so.0 (0x00007fb7d3d80000)
liblzma.so.5 => /lib/x86_64-linux-gnu/liblzma.so.5 (0x00007fb7d3b5e000)
libhsa-runtime64.so.1 => /opt/rocm/hsa/lib/libhsa-runtime64.so.1 (0x00007fb7d38c7000)
libhsakmt.so.1 => /opt/rocm/libhsakmt/lib/libhsakmt.so.1 (0x00007fb7d36a8000)
libelf.so.1 => /usr/lib/x86_64-linux-gnu/libelf.so.1 (0x00007fb7d3490000)
librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fb7d3288000)
libpci.so.3 => /lib/x86_64-linux-gnu/libpci.so.3 (0x00007fb7d307a000)
libresolv.so.2 => /lib/x86_64-linux-gnu/libresolv.so.2 (0x00007fb7d2e5f000)
libudev.so.1 => /lib/x86_64-linux-gnu/libudev.so.1 (0x00007fb7d2e3e000)

aditya4d (Contributor) commented:

Can you compile the IR to asm using llc -march=amdgcn -mcpu=gfx900 <file.ll>?

masahi (Member) commented Oct 10, 2017

Ok, here is the output of llc-5.0 -march=amdgcn -mcpu=gfx803 myadd_kernel.ll (not gfx900; my card is an R9 Nano):
https://gist.github.com/masahi/6c7f270240891bc7e1e82dd221e05903

I can also disassemble 'rocm_kernel.co', generated here.
The output of $ /opt/rocm/hcc/compiler/bin/llvm-objdump -disassemble -mcpu=gfx803 rocm_kernel.co:
https://gist.github.com/masahi/ba5b376d2ecd15c72f6ee599a84287a7

aditya4d (Contributor) commented:

The asm looks good to me. Are you still getting a runtime error?

masahi (Member) commented Oct 10, 2017

Yes, I either get a 'Memory access fault' error, or no error but the output array is [0., 0., 0., ...].

When I print the value of packed_nbytes here, it says 28 or 20 (the operator() is called twice; I don't know why).
Is this expected?

aditya4d (Contributor) commented:

@masahi Great! That is the bug we are seeing that I mentioned.

[12:04:46] src/runtime/rocm/rocm_device_api.cc:126: Doing GPUCopy 
[12:04:46] src/runtime/rocm/rocm_device_api.cc:128: HtoD: 0.400863
[12:04:46] src/runtime/rocm/rocm_device_api.cc:126: Doing GPUCopy 
[12:04:46] src/runtime/rocm/rocm_device_api.cc:128: HtoD: 0.0277189
[12:04:46] src/runtime/rocm/rocm_device_api.cc:126: Doing GPUCopy 
[12:04:46] src/runtime/rocm/rocm_device_api.cc:128: HtoD: 0
[12:04:46] src/runtime/rocm/rocm_device_api.cc:126: Doing GPUCopy 
[12:04:46] src/runtime/rocm/rocm_device_api.cc:137: DtoH: 0
[12:04:46] src/runtime/rocm/rocm_device_api.cc:126: Doing GPUCopy 
[12:04:46] src/runtime/rocm/rocm_device_api.cc:137: DtoH: 0.400863
Traceback (most recent call last):
  File "test_codegen_device.py", line 88, in <module>
    test_add_pipeline()
  File "test_codegen_device.py", line 85, in test_add_pipeline
    check_target("rocm", host="llvm")
  File "test_codegen_device.py", line 55, in check_target
    c.asnumpy(), a.asnumpy())
  File "/usr/lib/python2.7/dist-packages/numpy/testing/utils.py", line 1391, in assert_allclose
    verbose=verbose, header=header)
  File "/usr/lib/python2.7/dist-packages/numpy/testing/utils.py", line 733, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

(mismatch 100.0%)
 x: array([ 0.,  0.,  0., ...,  0.,  0.,  0.], dtype=float32)
 y: array([ 0.400863,  0.975084,  0.134123, ...,  0.170033,  0.325066,
        0.891966], dtype=float32)

Do you have any interesting observations?

masahi (Member) commented Oct 10, 2017

Nothing so far; I'm working on it.
So a packed_nbytes value of 28 or 20 is definitely wrong?

aditya4d (Contributor) commented:

Can you join the dlpack slack channel? We can discuss more there.

masahi (Member) commented Oct 10, 2017

Sure, but how can I join? I haven't used Slack before.

aditya4d (Contributor) commented:

It is dlpack.slack.com

masahi (Member) commented Oct 10, 2017

Ok, I'll ping @tqchen to send me an invite.

tqchen (Member) commented Oct 10, 2017

@masahi You can send an email to my UW email address.
