diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index 9aaa984a768..320023c9445 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -39,6 +39,7 @@ import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; import ai.djl.training.tracker.Tracker; +import ai.djl.translate.Batchifier; import ai.djl.translate.TranslateException; import java.io.IOException; import org.testng.Assert; @@ -106,7 +107,14 @@ public void testFreezeParameters() { NDArray labels = manager.arange(100.0f).reshape(new Shape(10, 10)); Batch batch = new Batch( - manager, new NDList(data), new NDList(labels), 1, null, null, 0, 1); + manager, + new NDList(data), + new NDList(labels), + 1, + Batchifier.STACK, + Batchifier.STACK, + 0, + 1); EasyTrain.trainBatch(trainer, batch); trainer.step();