forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CONTRIB] rocBLAS integration (apache#751)
* rocblas integration * fix include * fix lint
- Loading branch information
1 parent
a3a4b12
commit eba7c13
Showing
9 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,3 +80,6 @@ USE_MPS = 0 | |
|
||
# Whether use cuBLAS | ||
USE_CUBLAS = 0 | ||
|
||
# Whether use rocBlas | ||
USE_ROCBLAS = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
ROCBLAS_CONTRIB_SRC = $(wildcard src/contrib/rocblas/*.cc) | ||
ROCBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCBLAS_CONTRIB_SRC)) | ||
|
||
ifeq ($(USE_ROCBLAS), 1) | ||
CFLAGS += -DTVM_USE_ROCBLAS=1 | ||
ADD_LDFLAGS += -lrocblas | ||
RUNTIME_DEP += $(ROCBLAS_CONTRIB_OBJ) | ||
endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""External function interface to rocBLAS libraries.""" | ||
from __future__ import absolute_import as _abs | ||
|
||
from .. import api as _api | ||
from .. import intrin as _intrin | ||
|
||
def matmul(lhs, rhs, transa=False, transb=False): | ||
"""Create an extern op that compute matrix mult of A and rhs with rocBLAS | ||
Parameters | ||
---------- | ||
lhs : Tensor | ||
The left matrix operand | ||
rhs : Tensor | ||
The right matrix operand | ||
transa : bool | ||
Whether transpose lhs | ||
transb : bool | ||
Whether transpose rhs | ||
Returns | ||
------- | ||
C : Tensor | ||
The result tensor. | ||
""" | ||
n = lhs.shape[1] if transa else lhs.shape[0] | ||
m = rhs.shape[0] if transb else rhs.shape[1] | ||
return _api.extern( | ||
(n, m), [lhs, rhs], | ||
lambda ins, outs: _intrin.call_packed( | ||
"tvm.contrib.rocblas.matmul", | ||
ins[0], ins[1], outs[0], transa, transb), name="C") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file Use external rocblas library call. | ||
*/ | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/runtime/util.h> | ||
#include <dmlc/logging.h> | ||
#include "rocblas.h" | ||
|
||
namespace tvm { | ||
namespace contrib { | ||
|
||
using namespace runtime; | ||
|
||
#ifndef CHECK_ROCBLAS_ERROR | ||
#define CHECK_ROCBLAS_ERROR(error) \ | ||
if (error != rocblas_status_success) { \ | ||
fprintf(stderr, "rocBLAS error: "); \ | ||
if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \ | ||
if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \ | ||
if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \ | ||
if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \ | ||
if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \ | ||
if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \ | ||
fprintf(stderr, "\n"); \ | ||
exit(EXIT_FAILURE); \ | ||
} | ||
#endif | ||
|
||
|
||
// matrix multiplication for row major | ||
TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
DLTensor* A = args[0]; | ||
DLTensor* B = args[1]; | ||
DLTensor* C = args[2]; | ||
bool transa = args[3]; | ||
bool transb = args[4]; | ||
// call gemm for simple compact code. | ||
CHECK_EQ(A->ndim, 2); | ||
CHECK_EQ(B->ndim, 2); | ||
CHECK_EQ(C->ndim, 2); | ||
CHECK(C->strides == nullptr); | ||
CHECK(B->strides == nullptr); | ||
CHECK(A->strides == nullptr); | ||
CHECK(TypeMatch(A->dtype, kDLFloat, 32)); | ||
CHECK(TypeMatch(B->dtype, kDLFloat, 32)); | ||
CHECK(TypeMatch(C->dtype, kDLFloat, 32)); | ||
|
||
rocblas_handle handle; | ||
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); | ||
float alpha = 1.0; | ||
float beta = 0.0; | ||
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset); | ||
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset); | ||
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset); | ||
|
||
CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, | ||
transb ? rocblas_operation_transpose : rocblas_operation_none, | ||
transa ? rocblas_operation_transpose : rocblas_operation_none, | ||
transb ? B->shape[0] : B->shape[1], | ||
transa ? A->shape[1] : A->shape[0], | ||
transb ? B->shape[1] : B->shape[0], | ||
&alpha, | ||
A_ptr, | ||
B->shape[1], | ||
B_ptr, | ||
A->shape[1], | ||
&beta, | ||
C_ptr, | ||
C->shape[1])); | ||
|
||
CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); | ||
}); | ||
} // namespace contrib | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import tvm | ||
import numpy as np | ||
from tvm.contrib import rocblas | ||
|
||
def test_matmul_add(): | ||
n = 1024 | ||
l = 128 | ||
m = 235 | ||
A = tvm.placeholder((n, l), name='A') | ||
B = tvm.placeholder((l, m), name='B') | ||
C = rocblas.matmul(A, B) | ||
s = tvm.create_schedule(C.op) | ||
|
||
def verify(target="rocm"): | ||
if not tvm.module.enabled(target): | ||
print("skip because %s is not enabled..." % target) | ||
return | ||
if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True): | ||
print("skip because extern function is not avalable") | ||
return | ||
ctx = tvm.rocm(0) | ||
f = tvm.build(s, [A, B, C], target) | ||
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx) | ||
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx) | ||
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) | ||
f(a, b, c) | ||
np.testing.assert_allclose( | ||
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5) | ||
verify() | ||
|
||
|
||
if __name__ == "__main__": | ||
test_matmul_add() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ | |
from __future__ import absolute_import as _abs | ||
|
||
from .conv2d import * | ||
from .dense import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# pylint: disable=invalid-name, unused-variable | ||
"""Schedule for dense operator""" | ||
from __future__ import absolute_import as _abs | ||
import tvm | ||
from tvm.contrib import rocblas | ||
import topi | ||
from ..nn.dense import dense, dense_default | ||
from .. import tag | ||
from .. import generic | ||
|
||
@dense.register("rocm") | ||
def dense_rocm(data, weight, bias=None): | ||
"""Dense operator for rocm backend. | ||
Parameters | ||
---------- | ||
data : tvm.Tensor | ||
2-D with shape [batch, in_dim] | ||
weight : tvm.Tensor | ||
2-D with shape [out_dim, in_dim] | ||
bias : tvm.Tensor, optional | ||
1-D with shape [out_dim] | ||
Returns | ||
------- | ||
output : tvm.Tensor | ||
2-D with shape [batch, out_dim] | ||
""" | ||
assert len(data.shape) == 2 and len(weight.shape) == 2, \ | ||
"only support 2-dim dense" | ||
if bias is not None: | ||
assert len(bias.shape) == 1 | ||
batch, in_dim = data.shape | ||
out_dim, _ = weight.shape | ||
target = tvm.target.current_target() | ||
if "rocblas" in target.libs: | ||
matmul = rocblas.matmul(data, weight, False, True) | ||
if bias is not None: | ||
matmul = tvm.compute((batch, out_dim), \ | ||
lambda i, j: matmul[i, j] + bias[j], \ | ||
tag=tag.BROADCAST) | ||
return matmul | ||
return dense_default(data, weight, bias) | ||
|
||
|
||
@generic.schedule_dense.register(["rocm"]) | ||
def schedule_dense(outs): | ||
"""Schedule for dense operator. | ||
Parameters | ||
---------- | ||
outs: Array of Tensor | ||
The computation graph description of dense | ||
in the format of an array of tensors. | ||
Returns | ||
------- | ||
s: Schedule | ||
The computation schedule for dense. | ||
""" | ||
target = tvm.target.current_target() | ||
if target.target_name == "rocm" and "rocblas" in target.libs: | ||
return generic.schedule_extern(outs) | ||
return topi.cuda.schedule_dense(outs) |