Skip to content

Commit

Permalink
[CUDA][THRUST] Enforce -libs=thrust to allow thrust offload (#7468)
Browse files Browse the repository at this point in the history
* add contrib/thrust.py

* update cuda strategy

* remove is_thrust_available, update nms, scan, sort and tests

* remove unused import

* trigger CI

* update

* add note on how to enable thrust in ssd tutorial

* add warning

* Revert "update"

This reverts commit c1629b3.

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Feb 22, 2021
1 parent cfe88c1 commit 072c469
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 112 deletions.
1 change: 1 addition & 0 deletions apps/topi_recipe/gemm/cuda_gemm_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np
import tvm.testing

TASK = "gemm"
USE_MANUAL_CODE = False
Expand Down
45 changes: 45 additions & 0 deletions python/tvm/contrib/thrust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for thrust"""
import logging

from tvm._ffi import get_global_func


def maybe_warn(target, func_name):
    """Warn on a mismatch between the TVM build and the target's -libs setting.

    Two mismatches are detected:
    * TVM was built with thrust (the global func is registered) but the
      target does not request it via ``-libs=thrust`` — thrust goes unused.
    * The target requests thrust but TVM was built without it — the
      offload cannot happen.

    Parameters
    ----------
    target : tvm.target.Target
        The compilation target whose ``libs`` attribute is checked.
    func_name : str
        Name of the thrust global function to probe in the registry.
    """
    if get_global_func(func_name, allow_missing=True) and "thrust" not in target.libs:
        logging.warning("TVM is built with thrust but thrust is not used.")
    if "thrust" in target.libs and get_global_func(func_name, allow_missing=True) is None:
        logging.warning("thrust is requested but TVM is not built with thrust.")


def can_use_thrust(target, func_name):
    """Return True if ``func_name`` can be offloaded to thrust for this target.

    Requires all of: a CUDA-family target (``cuda`` or ``nvptx``), the
    ``-libs=thrust`` flag on the target, and a TVM build that registered
    the given thrust global function.

    Parameters
    ----------
    target : tvm.target.Target
        The compilation target.
    func_name : str
        Name of the thrust global function to probe in the registry.

    Returns
    -------
    bool
        Whether the thrust-backed implementation may be used.
    """
    maybe_warn(target, func_name)
    return (
        target.kind.name in ["cuda", "nvptx"]
        and "thrust" in target.libs
        # `is not None` so the function returns a plain bool rather than
        # the PackedFunc object itself.
        and get_global_func(func_name, allow_missing=True) is not None
    )


def can_use_rocthrust(target, func_name):
    """Return True if ``func_name`` can be offloaded to rocThrust for this target.

    Requires all of: a ``rocm`` target, the ``-libs=thrust`` flag on the
    target, and a TVM build that registered the given thrust global function.

    Parameters
    ----------
    target : tvm.target.Target
        The compilation target.
    func_name : str
        Name of the thrust global function to probe in the registry.

    Returns
    -------
    bool
        Whether the rocThrust-backed implementation may be used.
    """
    maybe_warn(target, func_name)
    return (
        target.kind.name == "rocm"
        and "thrust" in target.libs
        # `is not None` so the function returns a plain bool rather than
        # the PackedFunc object itself.
        and get_global_func(func_name, allow_missing=True) is not None
    )
18 changes: 5 additions & 13 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from tvm._ffi import get_global_func
from tvm.contrib.thrust import can_use_thrust
from .generic import *
from .. import op as _op

Expand Down Expand Up @@ -791,9 +791,7 @@ def scatter_cuda(attrs, inputs, out_type, target):
rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
):
if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
Expand Down Expand Up @@ -838,9 +836,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort.cuda",
)
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort_thrust),
wrap_topi_schedule(topi.cuda.schedule_sort),
Expand All @@ -859,9 +855,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda",
)
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
Expand All @@ -880,9 +874,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda",
)
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
Expand Down
19 changes: 6 additions & 13 deletions python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm._ffi import get_global_func
from tvm.contrib.thrust import can_use_rocthrust

from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule
Expand Down Expand Up @@ -223,14 +224,6 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
return strategy


def can_use_thrust(target, func_name):
return (
target.kind.name == "rocm"
and "thrust" in target.libs
and get_global_func(func_name, allow_missing=True)
)


@argsort_strategy.register(["rocm"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort rocm strategy"""
Expand All @@ -240,7 +233,7 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.rocm",
)
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
Expand All @@ -264,7 +257,7 @@ def scatter_cuda(attrs, inputs, out_type, target):
rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
Expand All @@ -283,7 +276,7 @@ def sort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort.rocm",
)
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort_thrust),
wrap_topi_schedule(topi.cuda.schedule_sort),
Expand All @@ -303,7 +296,7 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
name="topk.rocm",
)

if can_use_thrust(target, "tvm.contrib.thrust.sort"):
if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
"""Non-maximum suppression operator"""
import tvm
from tvm import te

from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, is_thrust_available
from .sort import argsort, argsort_thrust
from .scan import exclusive_scan
from ..utils import ceil_div

Expand Down Expand Up @@ -610,8 +610,10 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
)

target = tvm.target.Target.current()
# TODO(masahi): Check -libs=thrust option
if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
if target and (
can_use_thrust(target, "tvm.contrib.thrust.sort")
or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
):
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
Expand Down
13 changes: 5 additions & 8 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"Scan related operators"
import tvm
from tvm import te
from tvm._ffi import get_global_func
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from ..transform import expand_dims, squeeze, transpose, reshape
from ..utils import ceil_div, swap, prod, get_const_int
from ..math import cast
Expand Down Expand Up @@ -249,11 +249,6 @@ def ir(data, data_ex_scan, reduction):
return reduction


def is_thrust_available():
"""Test if thrust based scan ops are available."""
return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None


def scan_thrust(
data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add
):
Expand Down Expand Up @@ -352,8 +347,10 @@ def exclusive_scan(

def do_scan(data, output_dtype):
target = tvm.target.Target.current()
# TODO(masahi): Check -libs=thrust option
if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
if target and (
can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
):
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available
from .sort import stable_sort_by_key_thrust
from ..utils import prod, ceil_div


Expand Down Expand Up @@ -565,7 +565,6 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0):
if axis < 0:
axis += len(data.shape)
assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
assert is_thrust_available(), "Thrust is required for this op"

cfg.add_flop(1) # A dummy value to satisfy AutoTVM

Expand Down
8 changes: 0 additions & 8 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Sort related operators """
import tvm
from tvm import te
from tvm._ffi import get_global_func

from .injective import schedule_injective_from_existing
from ..transform import strided_slice, transpose
Expand Down Expand Up @@ -879,10 +878,3 @@ def stable_sort_by_key_thrust(keys, values, for_scatter=False):
tag="stable_sort_by_key",
)
return out[0], out[1]


def is_thrust_available():
"""
Test if thrust based sorting ops are available.
"""
return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
Loading

0 comments on commit 072c469

Please sign in to comment.