Skip to content

Commit

Permalink
feat(//py): Use TensorRT to fill in .so libraries automatically if
Browse files Browse the repository at this point in the history
possible

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jul 22, 2022
1 parent 1625cd3 commit 7680456
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 6 deletions.
17 changes: 13 additions & 4 deletions py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@

JETPACK_VERSION = None

__version__ = '1.2.0a0'
FX_ONLY = False

__version__ = '1.2.0a0'
__cuda_version__ = '11.3'
__cudnn_version__ = '8.2'
__tensorrt_version__ = '8.2'

def get_git_revision_short_hash() -> str:
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()

Expand All @@ -51,8 +55,10 @@ def get_git_revision_short_hash() -> str:
JETPACK_VERSION = "4.5"
elif version == "4.6":
JETPACK_VERSION = "4.6"
elif version == "5.0":
JETPACK_VERSION = "4.6"
if not JETPACK_VERSION:
warnings.warn("Assuming jetpack version to be 4.6, if not use the --jetpack-version option")
warnings.warn("Assuming jetpack version to be 4.6 or greater, if not use the --jetpack-version option")
JETPACK_VERSION = "4.6"


Expand Down Expand Up @@ -103,7 +109,7 @@ def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=F
print("Jetpack version: 4.5")
elif JETPACK_VERSION == "4.6":
cmd.append("--platforms=//toolchains:jetpack_4.6")
print("Jetpack version: 4.6")
print("Jetpack version: >=4.6")

print("building libtorchtrt")
status_code = subprocess.run(cmd).returncode
Expand All @@ -118,7 +124,10 @@ def gen_version_file():

with open(dir_path + '/torch_tensorrt/_version.py', 'w') as f:
print("creating version file")
f.write("__version__ = \"" + __version__ + '\"')
f.write("__version__ = \"" + __version__ + '\"\n')
f.write("__cuda_version__ = \"" + __cuda_version__ + '\"\n')
f.write("__cudnn_version__ = \"" + __cudnn_version__ + '\"\n')
f.write("__tensorrt_version__ = \"" + __tensorrt_version__ + '\"\n')


def copy_libtorchtrt(multilinux=False):
Expand Down
79 changes: 77 additions & 2 deletions py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,88 @@
import ctypes
import glob
import os
import sys
import platform
import warnings
from torch_tensorrt._version import __version__, __cuda_version__, __cudnn_version__, __tensorrt_version__


if sys.version_info < (3,):
raise Exception("Python 2 has reached end-of-life and is not supported by Torch-TensorRT")

import ctypes
def _parse_semver(version):
split = version.split(".")
if len(split) < 3:
split.append("")

return {
"major": split[0],
"minor": split[1],
"patch": split[2]
}

def _find_lib(name, paths):
for path in paths:
libpath = os.path.join(path, name)
if os.path.isfile(libpath):
return libpath

raise FileNotFoundError(
f"Could not find {name}\n Search paths: {paths}"
)

try:
import tensorrt
except:
cuda_version = _parse_semver(__cuda_version__)
cudnn_version = _parse_semver(__cudnn_version__)
tensorrt_version = _parse_semver(__tensorrt_version__)

CUDA_MAJOR = cuda_version["major"]
CUDNN_MAJOR = cudnn_version["major"]
TENSORRT_MAJOR = tensorrt_version["major"]

if sys.platform.startswith("win"):
WIN_LIBS = [
f"cublas64_{CUDA_MAJOR}.dll",
f"cublasLt64_{CUDA_MAJOR}.dll",
f"cudnn64_{CUDNN_MAJOR}.dll",
"nvinfer.dll",
"nvinfer_plugin.dll",
]

WIN_PATHS = os.environ["PATH"].split(os.path.pathsep)


for lib in WIN_LIBS:
ctypes.CDLL(_find_lib(lib, WIN_PATHS))

elif sys.platform.startswith("linux"):
LINUX_PATHS = [
"/usr/local/cuda/lib64",
] + os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep)

if platform.uname().processor == "x86_64":
LINUX_PATHS += [
"/usr/lib/x86_64-linux-gnu",
]

elif platform.uname().processor == "aarch64":
LINUX_PATHS += [
"/usr/lib/aarch64-linux-gnu"
]

LINUX_LIBS = [
f"libcudnn.so.{CUDNN_MAJOR}",
f"libnvinfer.so.{TENSORRT_MAJOR}",
f"libnvinfer_plugin.so.{TENSORRT_MAJOR}",
]

for lib in LINUX_LIBS:
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))

import torch

from torch_tensorrt._version import __version__
from torch_tensorrt._compile import *
from torch_tensorrt._util import *
from torch_tensorrt import ts
Expand Down

0 comments on commit 7680456

Please sign in to comment.