Skip to content

Commit

Permalink
add python definition stub
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 18, 2022
1 parent 3166952 commit c86b128
Showing 1 changed file with 156 additions and 0 deletions.
156 changes: 156 additions & 0 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,73 @@ def conv_find_algo(
)


def conv_backward_data_find_algo(
    tensor_format,
    pad,
    stride,
    dilation,
    x_shape,
    w_shape,
    y_shape,
    data_dtype,
    conv_dtype,
    groups=1,
):
    """Choose the best cuDNN algorithm for a convolution backward-data
    (input-gradient) computation with the given problem description.

    Parameters
    ----------
    tensor_format: int
        0: CUDNN_TENSOR_NCHW
        1: CUDNN_TENSOR_NHWC
        2: CUDNN_TENSOR_NCHW_VECT_C
    pad: int or list
        padding
    stride: int or list
        stride
    dilation: int or list
        dilation
    x_shape: list
        input shape
    w_shape: list
        weight shape
    y_shape: list
        output shape
    data_dtype: str
        data type
    conv_dtype: str
        convolution type
    groups: int
        number of groups

    Returns
    -------
    algo: int
        algo chosen by CUDNN
    """
    dims = len(x_shape)
    # Only 4D (2D conv) and 5D (3D conv) tensors are supported.
    assert dims in (4, 5)

    # Normalizes scalar pad/stride/dilation to per-spatial-dim int32 arrays
    # and converts the shape lists to int32 numpy arrays for the FFI call.
    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
        dims - 2, pad, stride, dilation, x_shape, w_shape
    )
    yshape = np.array(y_shape, dtype=np.int32)
    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_data_find_algo")
    return func(
        tensor_format,
        dims - 2,
        _get_np_int32_array_handle(pad),
        _get_np_int32_array_handle(stride),
        _get_np_int32_array_handle(dilation),
        _get_np_int32_array_handle(xshape),
        _get_np_int32_array_handle(wshape),
        _get_np_int32_array_handle(yshape),
        data_dtype,
        conv_dtype,
        groups,
    )


def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, groups=1):
"""Create an extern op that compute 2D or 3D convolution with CuDNN
Expand Down Expand Up @@ -496,6 +563,95 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co
)


def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1):
    """Create an extern op that computes the gradient of a convolution with
    respect to its data input, using cuDNN's backward-data kernel.

    Parameters
    ----------
    x: Tensor
        input tensor (presumably the gradient of the forward output, dy —
        TODO confirm against the C++ side of tvm.contrib.cudnn.conv2d.backward_data)
    w: Tensor
        convolution weight
    pad: int or list
        padding
    stride: int or list
        stride
    dilation: int or list
        dilation
    conv_mode: int
        0: CUDNN_CONVOLUTION
        1: CUDNN_CROSS_CORRELATION
    tensor_format: int
        0: CUDNN_TENSOR_NCHW
        1: CUDNN_TENSOR_NHWC
        2: CUDNN_TENSOR_NCHW_VECT_C
    conv_dtype: str
        convolution type
    groups: int
        the number of groups

    Returns
    -------
    y: Tensor
        The result tensor
    """
    dims = len(x.shape)
    # NOTE(review): the assert admits 5D (3D conv) inputs, but the packed call
    # below targets "conv2d.backward_data" and forwards only two spatial
    # pad/stride/dilation values — confirm whether 3D is actually supported.
    assert dims in (4, 5)

    conv_dtype = x.dtype if conv_dtype is None else conv_dtype
    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)

    x_shape = list(x.shape)

    # cuDNN algo selection below needs concrete shapes, so a static batch
    # dimension is required.
    assert isinstance(
        x.shape[0], tvm.tir.expr.IntImm
    ), "Dynamic batch is not supported for cudnn conv2d backward data yet."
    # NOTE(review): this computes the *forward* output shape of x; for a true
    # backward-data op the result should have the forward input's shape —
    # verify this stub's shape derivation against the final implementation.
    oshape = conv_output_shape(
        tensor_format,
        pad,
        stride,
        dilation,
        x_shape,
        list(w.shape),
        x.dtype,
        conv_dtype,
        groups,
    )
    # Pick the best cuDNN backward-data algorithm for this problem size.
    algo = conv_backward_data_find_algo(
        tensor_format,
        pad,
        stride,
        dilation,
        list(x.shape),
        list(w.shape),
        oshape,
        x.dtype,
        conv_dtype,
        groups,
    )

    return te.extern(
        oshape,
        [x, w],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.conv2d.backward_data",
            conv_mode,
            tensor_format,
            algo,
            pad[0],
            pad[1],
            stride[0],
            stride[1],
            dilation[0],
            dilation[1],
            ins[0],
            ins[1],
            outs[0],
            conv_dtype,
            groups,
        ),
        name="y",
    )


def softmax(x, axis=-1):
"""Compute softmax using CuDNN
Expand Down

0 comments on commit c86b128

Please sign in to comment.