diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 4213e47ea..7895ae6d7 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -2,6 +2,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import com.ibm.wala.cast.python.client.PythonAnalysisEngine; import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder; @@ -201,15 +202,15 @@ public void testTf2() 0, 0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once // https://github.com/wala/ML/issues/89 is fixed. - testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 2); + testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 3); testTf2( "tf2_test_model_call2.py", "SequentialModel.call", 0, 2); // NOTE: Change to testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 2) // once https://github.com/wala/ML/issues/106 is fixed. - testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 2); - testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 2); + testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 3); + testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 3); } private void testTf2( @@ -298,7 +299,11 @@ private void testTf2( assertEquals(expectedTensorParameterValueNumbers.length, actualValueNumberSet.size()); Arrays.stream(expectedTensorParameterValueNumbers) - .forEach(ev -> actualValueNumberSet.contains(ev)); + .forEach( + ev -> + assertTrue( + "Expecting " + actualValueNumberSet + " to contain " + ev + ".", + actualValueNumberSet.contains(ev))); // get the tensor variables for the function. Set functionTensors =