Skip to content

Commit

Permalink
Fix testFreezeParameters for multi-gpu (#1623)
Browse files Browse the repository at this point in the history
fixes #1606
  • Loading branch information
zachgk authored May 5, 2022
1 parent 0b5fee8 commit e66e836
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.testng.Assert;
Expand Down Expand Up @@ -106,7 +107,14 @@ public void testFreezeParameters() {
NDArray labels = manager.arange(100.0f).reshape(new Shape(10, 10));
Batch batch =
new Batch(
manager, new NDList(data), new NDList(labels), 1, null, null, 0, 1);
manager,
new NDList(data),
new NDList(labels),
1,
Batchifier.STACK,
Batchifier.STACK,
0,
1);
EasyTrain.trainBatch(trainer, batch);
trainer.step();

Expand Down

0 comments on commit e66e836

Please sign in to comment.