[CUDNN] Add partitioning support for fused conv2d+bias+act (apache#10997)

cuDNN has kernel support for the pattern conv2d+bias+act,
although as of v8 only relu is supported as the activation.
mbaret authored and altanh committed Apr 28, 2022
1 parent 4d392da commit e1a0c38
Showing 5 changed files with 186 additions and 12 deletions.
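For context, a minimal sketch (not part of this commit) of how the new fused offload is exercised from Python, assuming a TVM build with the cuDNN contrib enabled; shapes and variable names are illustrative:

import tvm
from tvm import relay
from tvm.relay.op.contrib.cudnn import partition_for_cudnn

# conv2d -> bias_add -> relu, the chain matched by the new "cudnn.conv2d_bias_act" pattern.
data = relay.var("data", relay.TensorType((1, 8, 16, 16), "float32"))
weight = relay.var("weight", relay.TensorType((16, 8, 3, 3), "float32"))
bias = relay.var("bias", relay.TensorType((16,), "float32"))
out = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.bias_add(out, bias)
out = relay.nn.relu(out)

mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))
# The whole chain should end up wrapped in a composite function offloaded to cuDNN.
partitioned = partition_for_cudnn(mod)
print(partitioned)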
79 changes: 70 additions & 9 deletions python/tvm/relay/op/contrib/cudnn.py
@@ -16,15 +16,14 @@
# under the License.
# pylint: disable=unused-argument
"""cuDNN Relay integration."""
from typing import Callable, List, Tuple, Dict, Optional
from typing import Callable, List, Tuple

import tvm
import tvm.ir
from tvm import relay
from tvm import te
from tvm.relay import transform
from tvm.contrib import cudnn
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
@@ -34,25 +33,19 @@
tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))


def partition_for_cudnn(
mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
) -> tvm.IRModule:
def partition_for_cudnn(mod: tvm.IRModule) -> tvm.IRModule:
"""Partition the graph to offload for cuDNN.
Parameters
----------
mod : tvm.IRModule
The module to partition.
params : Optional[Dict[str, tvm.runtime.NDArray]]
Constant input parameters.
Returns
-------
tvm.IRModule
The partitioned module.
"""
if params:
mod["main"] = bind_params_by_name(mod["main"], params)

seq = tvm.transform.Sequential(
[
@@ -82,6 +75,12 @@ def conv2d_pattern() -> relay.Pattern:
"""Create pattern for conv2d."""
return is_op("nn.conv2d")(wildcard(), wildcard())

def conv2d_bias_act_pattern() -> relay.Pattern:
"""Create pattern for fused conv2d+bias+activation."""
conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
bias = is_op("nn.bias_add")(conv2d, wildcard())
return bias.optional(is_op("nn.relu"))

def check_softmax(matched: relay.Call) -> bool:
"""Check if softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
@@ -115,9 +114,13 @@ def check_conv2d(matched: relay.Call) -> bool:

return True

def check_conv2d_bias_act(matched: relay.Call) -> bool:
"""Check if fused conv2d+bias+activation is supported by cuDNN."""
return True

return [
("cudnn.softmax", softmax_pattern(), check_softmax),
("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
("cudnn.conv2d_bias_act", conv2d_bias_act_pattern(), check_conv2d_bias_act),
("cudnn.conv2d", conv2d_pattern(), check_conv2d),
]

@@ -134,6 +137,64 @@ def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])


@lower_composite("cudnn.conv2d_bias_act")
def _lower_conv2d_bias_act(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a fused conv2d+bias+activation using cuDNN."""
conv_dtype = op.checked_type.dtype
if op.op.name == "nn.relu":
activation_mode = 1 # Relu
conv2d = op.args[0].args[0]
else:
activation_mode = 5 # Identity
conv2d = op.args[0]

conv_mode = 1  # CUDNN_CROSS_CORRELATION
tensor_format = 0  # CUDNN_TENSOR_NCHW
algo = 1  # CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
pad = conv2d.attrs["padding"]
strides = conv2d.attrs["strides"]
dilation = conv2d.attrs["dilation"]
groups = conv2d.attrs["groups"]

oshape = cudnn.conv_output_shape(
tensor_format,
pad,
strides,
dilation,
inputs[0].shape,
inputs[1].shape,
inputs[0].dtype,
conv_dtype,
groups,
)

return te.extern(
oshape,
inputs,
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.conv2d+bias+act.forward",
conv_mode,
tensor_format,
algo,
pad[0],
pad[1],
strides[0],
strides[1],
dilation[0],
dilation[1],
activation_mode,
0,  # activation coefficient (unused for relu/identity)
ins[0],
ins[1],
ins[2],
outs[0],
conv_dtype,
groups,
),
name="y",
)


@lower_composite("cudnn.conv2d")
def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a conv2d using cuDNN."""
62 changes: 62 additions & 0 deletions src/runtime/contrib/cudnn/conv_forward.cc
@@ -60,6 +60,44 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
entry_ptr->conv_entry.output_desc, y->data));
}

void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act,
double coef, const int pad[], const int stride[],
const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
DLTensor* bias, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
static_cast<cudnnActivationMode_t>(act),
cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format,
CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1));

SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
y->shape, x->dtype, conv_dtype);
// Set Device
entry_ptr->conv_entry.device = x->device;
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);

// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.fwd_algo, &workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
CUDNN_CALL(cudnnConvolutionBiasActivationForward(
entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
entry_ptr->conv_entry.workspace, workspace_size,
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data,
entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data));
}

void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -126,6 +164,30 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int mode = args[0];
int format = args[1];
int algo = args[2];
int pad_v[2], stride_v[2], dilation_v[2];
for (int i = 0; i < 2; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
}
int act = args[9];
double coef = args[10];
DLTensor* x = args[11];
DLTensor* w = args[12];
DLTensor* bias = args[13];
DLTensor* y = args[14];
std::string conv_dtype = args[15];
int groups = args[16];

ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v,
dilation_v, x, w, y, bias, conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int mode = args[0];
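For reference, a hedged sketch (not part of this commit) of driving the newly registered packed function directly; it assumes a cuDNN-enabled TVM build and a CUDA device, and follows the argument order of the registration above:

import numpy as np
import tvm

dev = tvm.cuda(0)
x = tvm.nd.array(np.random.uniform(size=(1, 8, 16, 16)).astype("float32"), dev)
w = tvm.nd.array(np.random.uniform(size=(16, 8, 3, 3)).astype("float32"), dev)
b = tvm.nd.array(np.random.uniform(size=(16,)).astype("float32"), dev)
y = tvm.nd.array(np.zeros((1, 16, 14, 14), dtype="float32"), dev)

conv2d_bias_act = tvm.get_global_func("tvm.contrib.cudnn.conv2d+bias+act.forward")
# mode=1 (cross-correlation), format=0 (NCHW), algo=1, pad=(0, 0), stride=(1, 1),
# dilation=(1, 1), act=1 (relu), coef=0.0, then x, w, bias, y, conv_dtype, groups.
conv2d_bias_act(1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0.0, x, w, b, y, "float32", 1)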
4 changes: 4 additions & 0 deletions src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -140,13 +140,17 @@ ConvEntry::ConvEntry() {
CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
}

ConvEntry::~ConvEntry() {
CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc));
CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
CleanWorkspace();
}

2 changes: 2 additions & 0 deletions src/runtime/contrib/cudnn/cudnn_utils.h
@@ -71,6 +71,8 @@ struct ConvEntry {
cudnnTensorFormat_t tensor_format;
cudnnTensorDescriptor_t input_desc;
cudnnFilterDescriptor_t filter_desc;
cudnnTensorDescriptor_t bias_desc;
cudnnActivationDescriptor_t activation_desc;
cudnnTensorDescriptor_t output_desc;
cudnnConvolutionFwdAlgo_t fwd_algo;
cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
51 changes: 48 additions & 3 deletions tests/python/contrib/test_cudnn.py
@@ -461,10 +461,12 @@ def _verify_cudnn_relay(expr):
for param in func.params:
shape = [int(x) for x in param.checked_type.shape]
input_data.append(
(param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
(
param.name_hint,
np.random.uniform(-32, 32, size=shape).astype(param.checked_type.dtype),
)
)

# Test against CPU reference
cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
outputs = []
@@ -484,7 +486,8 @@ def _verify_cudnn_relay(expr):
tvm.testing.assert_allclose(
outputs[0],
outputs[1],
rtol=1e-2,
rtol=1e-3,
atol=30,
)


@@ -577,5 +580,47 @@ def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding,
_verify_cudnn_relay(conv2d)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,h,w,ci,co,groups",
[
(1, 16, 20, 8, 16, 1),
(10, 17, 19, 16, 8, 4),
],
)
@pytest.mark.parametrize(
"kh,kw,padding,strides,dilation,dtype",
[
(1, 1, (3, 1, 3, 1), (1, 1), (1, 1), "float32"),
(3, 3, (1, 2), (2, 1), (2, 2), "float16"),
(7, 2, (0, 0), (3, 3), (1, 2), "float64"),
],
)
@pytest.mark.parametrize("activation", [True, False])
def test_relay_cudnn_conv2d_bias_act(
n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype, activation
):
data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
bias = relay.var("bias", relay.TensorType((co,), dtype))
conv2d = relay.op.nn.conv2d(
data,
weight,
groups=groups,
channels=co,
kernel_size=(kh, kw),
strides=strides,
dilation=dilation,
padding=padding,
data_layout="NCHW",
kernel_layout="OIHW",
)
out = relay.op.nn.bias_add(conv2d, bias)
if activation:
out = relay.op.nn.relu(out)

_verify_cudnn_relay(out)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
