From 168f7b7d24658b21deaf57a81783a66d7fcfe727 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 16 Apr 2022 14:15:07 -0700 Subject: [PATCH] [pytorch] Allows load libtroch from pip installation package Change-Id: Ic87302a880564027bed1f213088baac1ee543d9e --- .../java/ai/djl/pytorch/jni/LibUtils.java | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index 0deff350d9f..105ec60af4d 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -15,6 +15,7 @@ import ai.djl.util.ClassLoaderUtils; import ai.djl.util.Platform; import ai.djl.util.Utils; +import ai.djl.util.cuda.CudaUtils; import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -102,6 +103,7 @@ private static void loadLibTorch(LibTorch libTorch) { // PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually return; } + boolean isCuda = libTorch.flavor.contains("cu"); List deferred = Arrays.asList( System.mapLibraryName("fbgemm"), @@ -118,6 +120,12 @@ private static void loadLibTorch(LibTorch libTorch) { paths.filter( path -> { String name = path.getFileName().toString(); + if (!isCuda + && name.contains("nvrtc") + && name.contains("cudart") + && name.contains("nvTools")) { + return false; + } return !loadLater.contains(name) && Files.isRegularFile(path) && !name.endsWith(JNI_LIB_NAME) @@ -139,6 +147,14 @@ private static void loadLibTorch(LibTorch libTorch) { loadNativeLibrary(libDir.resolve("cudnn64_7.dll").toString()); } + if (!isCuda) { + deferred = + Arrays.asList( + System.mapLibraryName("fbgemm"), + System.mapLibraryName("torch_cpu"), + System.mapLibraryName("torch")); + } + for (String dep : deferred) { Path path = libDir.resolve(dep); if (Files.exists(path)) { @@ -289,6 +305,7 @@ private static LibTorch copyNativeLibraryFromClasspath(Platform platform) { if (Files.exists(path)) { return new LibTorch(dir.toAbsolutePath(), platform, flavor); } + Utils.deleteQuietly(dir); Matcher m = VERSION_PATTERN.matcher(version); if (!m.matches()) { @@ -498,7 +515,11 @@ private static final class LibTorch { if (flavor == null) { flavor = System.getProperty("PYTORCH_FLAVOR"); if (flavor == null) { - flavor = "cpu-precxx11"; + if (CudaUtils.getGpuCount() > 0) { + flavor = "cu" + CudaUtils.getCudaVersionString() + "-precxx11"; + } else { + flavor = "cpu-precxx11"; + } } } }