Skip to content

Commit

Permalink
[pytorch] Allows load libtroch from pip installation package
Browse files Browse the repository at this point in the history
Change-Id: Ic87302a880564027bed1f213088baac1ee543d9e
  • Loading branch information
frankfliu committed Apr 16, 2022
1 parent 0d06682 commit 168f7b7
Showing 1 changed file with 22 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Platform;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -102,6 +103,7 @@ private static void loadLibTorch(LibTorch libTorch) {
// PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually
return;
}
boolean isCuda = libTorch.flavor.contains("cu");
List<String> deferred =
Arrays.asList(
System.mapLibraryName("fbgemm"),
Expand All @@ -118,6 +120,12 @@ private static void loadLibTorch(LibTorch libTorch) {
paths.filter(
path -> {
String name = path.getFileName().toString();
if (!isCuda
&& name.contains("nvrtc")
&& name.contains("cudart")
&& name.contains("nvTools")) {
return false;
}
return !loadLater.contains(name)
&& Files.isRegularFile(path)
&& !name.endsWith(JNI_LIB_NAME)
Expand All @@ -139,6 +147,14 @@ private static void loadLibTorch(LibTorch libTorch) {
loadNativeLibrary(libDir.resolve("cudnn64_7.dll").toString());
}

if (!isCuda) {
deferred =
Arrays.asList(
System.mapLibraryName("fbgemm"),
System.mapLibraryName("torch_cpu"),
System.mapLibraryName("torch"));
}

for (String dep : deferred) {
Path path = libDir.resolve(dep);
if (Files.exists(path)) {
Expand Down Expand Up @@ -289,6 +305,7 @@ private static LibTorch copyNativeLibraryFromClasspath(Platform platform) {
if (Files.exists(path)) {
return new LibTorch(dir.toAbsolutePath(), platform, flavor);
}
Utils.deleteQuietly(dir);

Matcher m = VERSION_PATTERN.matcher(version);
if (!m.matches()) {
Expand Down Expand Up @@ -498,7 +515,11 @@ private static final class LibTorch {
if (flavor == null) {
flavor = System.getProperty("PYTORCH_FLAVOR");
if (flavor == null) {
flavor = "cpu-precxx11";
if (CudaUtils.getGpuCount() > 0) {
flavor = "cu" + CudaUtils.getCudaVersionString() + "-precxx11";
} else {
flavor = "cpu-precxx11";
}
}
}
}
Expand Down

0 comments on commit 168f7b7

Please sign in to comment.