[ROCm] hipBLAS integration #17290

Merged (1 commit) on Aug 22, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -107,6 +107,7 @@ tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_CURAND "Build with cuRAND" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_HIPBLAS "Build with ROCM:HIPBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" ON)
tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_LIBTORCH "Build with libtorch support" OFF)
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -116,6 +116,7 @@ function(add_lib_info src_file)
TVM_INFO_TVM_DEBUG_WITH_ABI_CHANGE="${TVM_DEBUG_WITH_ABI_CHANGE}"
TVM_INFO_TVM_LOG_BEFORE_THROW="${TVM_LOG_BEFORE_THROW}"
TVM_INFO_USE_ROCBLAS="${USE_ROCBLAS}"
TVM_INFO_USE_HIPBLAS="${USE_HIPBLAS}"
TVM_INFO_USE_ROCM="${USE_ROCM}"
TVM_INFO_USE_RCCL="${USE_RCCL}"
TVM_INFO_USE_RPC="${USE_RPC}"
12 changes: 12 additions & 0 deletions cmake/modules/ROCM.cmake
@@ -53,6 +53,18 @@ if(USE_ROCM)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
endif(USE_ROCBLAS)

if(USE_HIPBLAS)
message(STATUS "Build with HIPBLAS support")
tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRC src/relax/backend/contrib/hipblas/*.cc)
list(APPEND COMPILER_SRCS ${HIPBLAS_CONTRIB_SRC})
tvm_file_glob(GLOB HIPBLAS_CONTRIB_SRCS src/runtime/contrib/hipblas/*.cc)
list(APPEND RUNTIME_SRCS ${HIPBLAS_CONTRIB_SRCS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLAS_LIBRARY})
if(NOT ROCM_HIPBLASLT_LIBRARY STREQUAL "ROCM_HIPBLASLT_LIBRARY-NOTFOUND")
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPBLASLT_LIBRARY})
endif()
endif(USE_HIPBLAS)

if(USE_THRUST)
message(STATUS "Build with rocThrust support")
# We need to override CXX to hipcc. This is required by rocthrust
4 changes: 4 additions & 0 deletions cmake/utils/FindROCM.cmake
@@ -55,6 +55,8 @@ macro(find_rocm use_rocm)
endif()
find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib)
find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib)
find_library(ROCM_HIPBLAS_LIBRARY hipblas ${__rocm_sdk}/lib)
find_library(ROCM_HIPBLASLT_LIBRARY hipblaslt ${__rocm_sdk}/lib)
find_library(ROCM_HSA_LIBRARY hsa-runtime64 ${__rocm_sdk}/lib)

if(ROCM_HIPHCC_LIBRARY)
@@ -66,5 +68,7 @@ macro(find_rocm use_rocm)
message(STATUS "Found ROCM_HIPHCC_LIBRARY=" ${ROCM_HIPHCC_LIBRARY})
message(STATUS "Found ROCM_MIOPEN_LIBRARY=" ${ROCM_MIOPEN_LIBRARY})
message(STATUS "Found ROCM_ROCBLAS_LIBRARY=" ${ROCM_ROCBLAS_LIBRARY})
message(STATUS "Found ROCM_HIPBLAS_LIBRARY=" ${ROCM_HIPBLAS_LIBRARY})
message(STATUS "Found ROCM_HIPBLASLT_LIBRARY=" ${ROCM_HIPBLASLT_LIBRARY})
endif(ROCM_FOUND)
endmacro(find_rocm)
86 changes: 86 additions & 0 deletions python/tvm/contrib/hipblas.py
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to hipBLAS libraries."""
import tvm
from tvm import te


def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS

Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether to transpose lhs
transb : bool
Whether to transpose rhs
dtype : str, optional
The output dtype; defaults to the dtype of lhs

Returns
-------
C : Tensor
The result tensor.
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return te.extern(
(n, m),
[lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.hipblas.matmul", ins[0], ins[1], outs[0], transa, transb
),
dtype=dtype,
name="matmul_hipblas",
)


def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS

Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether to transpose lhs
transb : bool
Whether to transpose rhs
dtype : str, optional
The output dtype; defaults to the dtype of lhs

Returns
-------
C : Tensor
The result tensor.
"""
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
dtype = dtype if dtype is not None else lhs.dtype
return te.extern(
(b, n, m),
[lhs, rhs],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.hipblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
),
dtype=dtype,
name="batch_matmul_hipblas",
)
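
A minimal usage sketch (not part of this diff): how these wrappers can be called from TE, assuming TVM was built with USE_HIPBLAS=ON; the shapes and names below are illustrative only.

    import tvm
    from tvm import te
    from tvm.contrib import hipblas

    # Declare the GEMM operands (float16 shown here).
    A = te.placeholder((1024, 128), name="A", dtype="float16")
    B = te.placeholder((128, 256), name="B", dtype="float16")

    # C is a TE extern op whose body is a packed call to "tvm.contrib.hipblas.matmul";
    # it would then be scheduled and built for the rocm target in the usual way.
    C = hipblas.matmul(A, B, transa=False, transb=False)

    # The batched variant takes 3-D operands with the batch dimension first.
    X = te.placeholder((8, 1024, 128), name="X", dtype="float16")
    Y = te.placeholder((8, 128, 256), name="Y", dtype="float16")
    Z = hipblas.batch_matmul(X, Y)
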
180 changes: 180 additions & 0 deletions python/tvm/relax/backend/contrib/hipblas.py
@@ -0,0 +1,180 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pattern table for hipblas backend"""
import operator
from functools import reduce

import tvm
from tvm.relax import transform
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_matmul_pattern
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): # pylint: disable=unused-argument
"""Check if dtypes in the given workload are supported by hipblas BYOC."""
if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
# The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
# return out_dtype != "e5m2_float8"
return False
return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
lhs_dtype == "int8" and rhs_dtype == "int8"
)


def _check_matmul(context: PatternCheckContext) -> bool:
if has_leaking_intermediate_variables(context):
return False
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]
matmul_call = context.annotated_expr["root"]

lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
out_dtype = matmul_call.struct_info.dtype
if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
return False

lhs_shape = lhs.struct_info.shape.values
rhs_shape = rhs.struct_info.shape.values

if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
# Reduction axis must be constant
return False

if lhs_dtype == "int8" and rhs_dtype == "int8":
return False
elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
return False

lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)

if "bias" in context.annotated_expr:
if lhs_dtype == "int8" and rhs_dtype == "int8":
# Non-default epilogue not supported for IGEMM
return False
bias = context.annotated_expr["bias"]
bias_shape = bias.struct_info.shape.values
bias_batches = reduce(operator.mul, bias_shape[:-1], 1)
if not isinstance(bias_batches, (tvm.tir.expr.IntImm, int)) or int(bias_batches) > 1:
# hipblas only supports bias vector
return False

# hipblasLt does not seem to support batched GEMM with one of matrices having
# one batch (with batch_stride 0). So for batched GEMM, the two batch counts
# must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by
# flattening all batch axes into the M axis.
return (
isinstance(lhs_batches, tvm.tir.Var)
or isinstance(rhs_batches, tvm.tir.Var)
or (int(lhs_batches) == int(rhs_batches))
or (lhs_batches >= 1 and rhs_batches == 1)
)


register_patterns(
[
(
"hipblas.matmul",
*make_matmul_pattern(
with_bias=False,
),
_check_matmul,
),
(
"hipblas.matmul_bias",
*make_matmul_pattern(
with_bias=True,
),
_check_matmul,
),
(
"hipblas.matmul_bias_relu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
),
_check_matmul,
),
(
"hipblas.matmul_bias_gelu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
),
_check_matmul,
),
(
"hipblas.matmul_transposed",
*make_matmul_pattern(
with_bias=False,
transposed_rhs=True,
),
_check_matmul,
),
(
"hipblas.matmul_transposed_bias",
*make_matmul_pattern(
with_bias=True,
transposed_rhs=True,
),
_check_matmul,
),
(
"hipblas.matmul_transposed_bias_relu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.relu",
transposed_rhs=True,
),
_check_matmul,
),
(
"hipblas.matmul_transposed_bias_gelu",
*make_matmul_pattern(
with_bias=True,
activation="relax.nn.gelu",
transposed_rhs=True,
),
_check_matmul,
),
]
)


def partition_for_hipblas(mod):
"""
Partition the input module into hipblas-supported subgraphs.

Parameters
----------
mod: tvm.IRModule
The IRModule to be partitioned.

Returns
-------
mod: tvm.IRModule
The resulting IRModule, containing partitioned subgraphs to be
offloaded to the hipblas backend.
"""

patterns = get_patterns_with_prefix("hipblas")
return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
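
An illustrative sketch (not part of this diff) of how the partitioning entry point might be used, assuming a TVM build with USE_HIPBLAS=ON so that the hipblas codegen and runtime are available; the module, shapes, and names below are hypothetical.

    import tvm
    from tvm import relax
    from tvm.script import relax as R
    from tvm.relax.backend.contrib.hipblas import partition_for_hipblas

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((16, 32), "float16"), w: R.Tensor((32, 64), "float16")):
            with R.dataflow():
                y = R.matmul(x, w)
                R.output(y)
            return y

    # Group matmul (+ optional bias/activation) subgraphs into composite functions
    # annotated with Codegen="hipblas" ...
    mod = partition_for_hipblas(Module)
    # ... and lower them to external runtime calls via the registered codegen.
    mod = relax.transform.RunCodegen()(mod)
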
3 changes: 3 additions & 0 deletions python/tvm/testing/utils.py
@@ -949,6 +949,9 @@ def _multi_gpu_exists():
parent_features="rocm",
)

# Mark a test as requiring the hipBLAS library.
requires_hipblas = Feature("hipblas", "hipBLAS", cmake_flag="USE_HIPBLAS", parent_features="rocm")

# Mark a test as requiring the metal runtime
requires_metal = Feature(
"metal",
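
For illustration (not part of this diff): the new marker is used like the existing feature markers, e.g. to gate a hypothetical test on a hipBLAS-enabled build.

    import tvm.testing

    @tvm.testing.requires_hipblas
    def test_matmul_offload():
        # Skipped unless TVM was built with USE_HIPBLAS=ON and the rocm
        # parent feature is available in the test environment.
        ...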