Skip to content

Commit

Permalink
[api] Loads native engine in deterministic order (#3300)
Browse files Browse the repository at this point in the history
Fixes: #3296
  • Loading branch information
frankfliu authored Jul 5, 2024
1 parent 725851b commit 90d6019
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 19 deletions.
48 changes: 34 additions & 14 deletions api/src/main/java/ai/djl/util/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Properties;

/**
Expand Down Expand Up @@ -71,32 +73,50 @@ public static Platform detectPlatform(String engine) {
}

Platform systemPlatform = Platform.fromSystem(engine);
Platform placeholder = null;
List<Platform> availablePlatforms = new ArrayList<>();
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
Platform platform = Platform.fromUrl(url);
platform.apiVersion = systemPlatform.apiVersion;
if (platform.isPlaceholder()) {
placeholder = platform;
availablePlatforms.add(platform);
} else if (platform.matches(systemPlatform)) {
logger.info("Found matching platform from: {}", url);
return platform;
availablePlatforms.add(platform);
} else {
logger.info("Ignore mismatching platform from: {}", url);
}
}
if (placeholder != null) {
logger.info("Found placeholder platform from: {}", placeholder);
return placeholder;
}

if (systemPlatform.version == null) {
throw new AssertionError("No " + engine + " version found in property file.");
}
if (systemPlatform.apiVersion == null) {
throw new AssertionError("No " + engine + " djl_version found in property file.");
if (availablePlatforms.isEmpty()) {
if (systemPlatform.version == null) {
throw new AssertionError("No " + engine + " version found in property file.");
}
if (systemPlatform.apiVersion == null) {
throw new AssertionError("No " + engine + " djl_version found in property file.");
}
return systemPlatform;
} else if (availablePlatforms.size() == 1) {
Platform ret = availablePlatforms.get(0);
if (ret.isPlaceholder()) {
logger.info("Found placeholder platform from: {}", ret);
}
return ret;
}
return systemPlatform;
availablePlatforms.sort(
(o1, o2) -> {
if (o1.isPlaceholder()) {
return 1;
} else if (o2.isPlaceholder()) {
return -1;
}
// cu121-precx11 > cu121 > cu118-precss11 > cpu-precxx11 > cpu
int ret = o2.getFlavor().compareTo(o1.getFlavor());
if (ret == 0) {
return o2.getVersion().compareTo(o1.getVersion());
}
return ret;
});
return availablePlatforms.get(0);
}

/**
Expand Down
17 changes: 12 additions & 5 deletions api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,13 @@ public static String getCudaVersionString() {
*/
public static String getComputeCapability(int device) {
if (Boolean.getBoolean("ai.djl.util.cuda.fork")) {
String[] ret = execute(device);
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
if (gpuInfo == null) { // NOPMD
gpuInfo = execute(-1);
}
return ret[0];
if (device >= gpuInfo.length - 2) {
throw new IllegalArgumentException("Invalid device: " + device);
}
return gpuInfo[device + 2];
}

if (LIB == null) {
Expand Down Expand Up @@ -214,7 +216,12 @@ public static void main(String[] args) {
return;
}
int cudaVersion = getCudaVersion();
System.out.println(gpuCount + "," + cudaVersion);
StringBuilder sb = new StringBuilder();
sb.append(gpuCount).append(',').append(cudaVersion);
for (int i = 0; i < gpuCount; ++i) {
sb.append(',').append(getComputeCapability(i));
}
System.out.println(sb);
return;
}
try {
Expand Down
64 changes: 64 additions & 0 deletions api/src/test/java/ai/djl/util/PlatformTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
*/
package ai.djl.util;

import ai.djl.util.cuda.CudaUtils;

import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

import java.io.BufferedWriter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;

public class PlatformTest {

Expand Down Expand Up @@ -94,6 +101,43 @@ public void testPlatform() throws IOException {
Assert.assertFalse(platform.matches(system));
}

@Test
public void testDetectPlatform() throws IOException, ReflectiveOperationException {
Path dir = Paths.get("build/tmp/");
Files.createDirectories(dir);
Platform system = Platform.fromSystem();
String classifier = system.getClassifier();
createZipFile(0, "1.0", "cpu", classifier, true);
createZipFile(1, "1.0", "cpu", classifier, false);
createZipFile(2, "1.0", "cpu-precxx11", classifier, false);
createZipFile(3, "1.0", "cu117", classifier, false);
createZipFile(4, "1.0", "cu117-precxx11", classifier, false);
createZipFile(5, "1.0", "cu999", classifier, false);
createZipFile(6, "1.0", "cu999-precxx11", classifier, false);
createZipFile(7, "99.99", "cu999-precxx11", classifier, false);
System.setProperty("ai.djl.util.cuda.fork", "true");
try {
String[] gpuInfo = new String[] {"1", "99990", "90"};
Field field = CudaUtils.class.getDeclaredField("gpuInfo");
field.setAccessible(true);
field.set(null, gpuInfo);
URL[] urls = new URL[8];
for (int i = 0; i < 8; ++i) {
urls[i] = dir.resolve(i + ".jar").toUri().toURL();
}
URLClassLoader cl = new URLClassLoader(urls);
Thread.currentThread().setContextClassLoader(cl);

Platform detected = Platform.detectPlatform("pytorch");
Assert.assertEquals(detected.getFlavor(), "cu999-precxx11");

field.set(null, null);
} finally {
System.clearProperty("ai.djl.util.cuda.fork");
Thread.currentThread().setContextClassLoader(null);
}
}

private URL createPropertyFile(String content) throws IOException {
Path dir = Paths.get("build/tmp/testFile/");
Files.createDirectories(dir);
Expand All @@ -104,4 +148,24 @@ private URL createPropertyFile(String content) throws IOException {
}
return file.toUri().toURL();
}

private void createZipFile(
int index, String version, String flavor, String classifier, boolean placeHolder)
throws IOException {
Path file = Paths.get("build/tmp/" + index + ".jar");
try (JarOutputStream jos = new JarOutputStream(Files.newOutputStream(file))) {
JarEntry entry = new JarEntry("native/lib/pytorch.properties");
jos.putNextEntry(entry);
if (placeHolder) {
jos.write("placeholder=true\nversion=2.3.1".getBytes(StandardCharsets.UTF_8));
} else {
jos.write("version=".getBytes(StandardCharsets.UTF_8));
jos.write(version.getBytes(StandardCharsets.UTF_8));
jos.write("\nflavor=".getBytes(StandardCharsets.UTF_8));
jos.write(flavor.getBytes(StandardCharsets.UTF_8));
jos.write("\nclassifier=".getBytes(StandardCharsets.UTF_8));
jos.write(classifier.getBytes(StandardCharsets.UTF_8));
}
}
}
}
2 changes: 2 additions & 0 deletions api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public void testCudaUtilsWithFork() {
System.setProperty("ai.djl.util.cuda.fork", "true");
try {
testCudaUtils();
CudaUtils.main(new String[0]);
CudaUtils.main(new String[] {"-1"});
} finally {
System.clearProperty("ai.djl.util.cuda.fork");
}
Expand Down

0 comments on commit 90d6019

Please sign in to comment.