From 239c908a494dbc7b75688348436de21dda258cdc Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 6 May 2022 19:39:28 -0700 Subject: [PATCH] assertion_err --- .../java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 41db19eda6b..1d677b9c28b 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -60,7 +60,9 @@ public void testGather() { // In the dependencies, changing runtimeOnly to api however will remedy the problem. // TODO: remove this when gradle problem is fixed. TestRequirements.notWindows(); - try (NDManager manager = NDManager.newBaseManager()) { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { +// try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = arr.gather(index, 1); @@ -73,7 +75,7 @@ public void testGather() { public void testTake() { Engine engine = Engine.getEngine("PyTorch"); try (NDManager manager = engine.newBaseManager()) { - NDArray arr = manager.arange(6f).reshape(-1, 3); + NDArray arr = manager.arange(1,7f).reshape(-1, 3); NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = arr.take(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));