Skip to content

Commit

Permalink
[api] Fixes nightly tests on GPU machine
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 4, 2024
1 parent 725851b commit 8c37cc6
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 31 deletions.
25 changes: 0 additions & 25 deletions api/src/test/java/ai/djl/util/SecurityManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@
*/
package ai.djl.util;

import ai.djl.util.cuda.CudaUtils;

import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

import java.io.FilePermission;
import java.security.Permission;

public class SecurityManagerTest {
Expand Down Expand Up @@ -59,26 +56,4 @@ public void checkPermission(Permission perm) {
Assert.assertEquals(Utils.getenv("TEST", "test"), "test");
Assert.assertEquals(Utils.getenv().size(), 0);
}

/**
 * Checks that {@code CudaUtils} reports no CUDA support and zero GPUs when a security manager
 * rejects every file permission whose name mentions the cudart library.
 */
@Test
public void testCudaUtils() {
    // Block access to the cudart library files; every other permission check passes.
    SecurityManager cudartBlocker =
            new SecurityManager() {
                @Override
                public void checkPermission(Permission perm) {
                    if (!(perm instanceof FilePermission)) {
                        return;
                    }
                    String target = perm.getName();
                    if (target.contains("cudart")) {
                        throw new SecurityException(
                                "Don't have permission to read file: " + target);
                    }
                }
            };
    System.setSecurityManager(cudartBlocker);
    try {
        Assert.assertFalse(CudaUtils.hasCuda());
        Assert.assertEquals(CudaUtils.getGpuCount(), 0);
    } finally {
        // Always remove the restrictive manager so later code is not affected.
        System.setSecurityManager(null);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class MaskDetectionTest {
@Test
public void testMaskDetection() throws ModelException, TranslateException, IOException {
TestRequirements.linux();
TestRequirements.notGpu();

DetectedObjects result = MaskDetection.predict();
logger.info("{}", result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void testObjectDetection() throws ModelException, TranslateException, IOE
// throttling and will fail the test.
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.notGpu();

DetectedObjects result = ObjectDetectionWithTensorflowSavedModel.predict();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class SuperResolutionTest {
@Test
public void testSuperResolution() throws ModelException, TranslateException, IOException {
TestRequirements.linux();
TestRequirements.notGpu();

String imagePath = "src/test/resources/";
Image fox = ImageFactory.getInstance().fromFile(Paths.get(imagePath + "fox.png"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class UniversalSentenceEncoderTest {
public void testSentimentAnalysis() throws ModelException, TranslateException, IOException {
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.notGpu();

List<String> inputs = new ArrayList<>();
inputs.add("The quick brown fox jumps over the lazy dog.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class Yolov8DetectionTest {

@Test
public void testYolov8Detection() throws ModelException, TranslateException, IOException {
TestRequirements.linux();
TestRequirements.notGpu();

DetectedObjects result = Yolov8Detection.predict();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class TrainResNetTest {
@Test
public void testTrainResNet() throws ModelException, IOException, TranslateException {
TestRequirements.nightly();
TestRequirements.gpu("PyTorch", 1);

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
Expand All @@ -47,7 +48,7 @@ public void testTrainResNetImperativeNightly()
throws ModelException, IOException, TranslateException {
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.gpu("PyTorch");
TestRequirements.gpu("PyTorch", 4);

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void testTrainSentimentAnalysis()
throws ModelException, TranslateException, IOException {
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.gpu("MXNet");
TestRequirements.gpu("MXNet", 1);

String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"};
TrainSentimentAnalysis.runExample(args);
Expand Down
14 changes: 11 additions & 3 deletions examples/src/test/java/ai/djl/testing/TestRequirements.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;

import ai.djl.util.cuda.CudaUtils;
import org.testng.SkipException;

import java.util.Arrays;
Expand Down Expand Up @@ -76,9 +77,16 @@ public static void linux() {
}

/** Skips the test unless the given engine reports at least {@code numGpu} GPUs. */
public static void gpu(String engine) {
if (Engine.getEngine(engine).getGpuCount() == 0) {
throw new SkipException("This test requires a GPU to run");
public static void gpu(String engine, int numGpu) {
if (Engine.getEngine(engine).getGpuCount() < numGpu) {
throw new SkipException("This test requires " + numGpu + " GPUs to run");
}
}

/**
 * Skips the test when {@code CudaUtils} reports one or more GPUs.
 *
 * <p>Intended to avoid OOM on GPUs with multiple engines.
 */
public static void notGpu() {
    int gpuCount = CudaUtils.getGpuCount();
    if (gpuCount <= 0) {
        return;
    }
    throw new SkipException("This test requires CPU only machine to run");
}
}

0 comments on commit 8c37cc6

Please sign in to comment.