Skip to content

Commit

Permalink
assertion_err
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 7, 2022
1 parent dc1b7a8 commit 239c908
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ public void testGather() {
// In the dependencies, changing runtimeOnly to api however will remedy the problem.
// TODO: remove this when gradle problem is fixed.
TestRequirements.notWindows();
try (NDManager manager = NDManager.newBaseManager()) {
Engine engine = Engine.getEngine("PyTorch");
try (NDManager manager = engine.newBaseManager()) {
// try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
Expand All @@ -73,7 +75,7 @@ public void testGather() {
public void testTake() {
Engine engine = Engine.getEngine("PyTorch");
try (NDManager manager = engine.newBaseManager()) {
NDArray arr = manager.arange(6f).reshape(-1, 3);
NDArray arr = manager.arange(1,7f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = arr.take(index);
NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Expand Down

0 comments on commit 239c908

Please sign in to comment.