From e66e836d78e0d08bd2a1fb9f58195f6988640374 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Thu, 5 May 2022 11:42:34 -0700 Subject: [PATCH] Fix testFreezeParameters for multi-gpu (#1623) fixes #1606 --- .../training/GradientCollectorIntegrationTest.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index 9aaa984a768..320023c9445 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -39,6 +39,7 @@ import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; import ai.djl.training.tracker.Tracker; +import ai.djl.translate.Batchifier; import ai.djl.translate.TranslateException; import java.io.IOException; import org.testng.Assert; @@ -106,7 +107,14 @@ public void testFreezeParameters() { NDArray labels = manager.arange(100.0f).reshape(new Shape(10, 10)); Batch batch = new Batch( - manager, new NDList(data), new NDList(labels), 1, null, null, 0, 1); + manager, + new NDList(data), + new NDList(labels), + 1, + Batchifier.STACK, + Batchifier.STACK, + 0, + 1); EasyTrain.trainBatch(trainer, batch); trainer.step();