Skip to content

Commit

Permalink
Add callback for MirroredStrategy.run() (#61)
Browse files Browse the repository at this point in the history
Fixed wala#92.
  • Loading branch information
khatchad committed Dec 4, 2023
1 parent d1a1d97 commit b96d758
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ public void testTf2()
// once https://github.com/wala/ML/issues/106 is fixed.
testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 2);
testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 2);
testTf2("tf2_test_callbacks.py", "replica_fn", 1, 3, 2);
testTf2("tf2_test_callbacks2.py", "replica_fn", 1, 4, 2);
}

private void testTf2(
Expand Down
26 changes: 26 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
<new def="estimator" class="Lobject" />
<putfield class="LRoot" field="estimator" fieldType="LRoot" ref="x" value="estimator" />

<new def="distribute" class="Lobject" />
<putfield class="LRoot" field="distribute" fieldType="LRoot" ref="x" value="distribute" />

<new def="nn" class="Lobject" />
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
<new def="random" class="Lobject" />
Expand All @@ -58,6 +61,9 @@
<new def="Estimator" class="Ltensorflow/estimator/Estimator" />
<putfield class="LRoot" field="Estimator" fieldType="LRoot" ref="estimator" value="Estimator" />

<new def="MirroredStrategy" class="Ltensorflow/distribute/MirroredStrategy" />
<putfield class="LRoot" field="MirroredStrategy" fieldType="LRoot" ref="distribute" value="MirroredStrategy" />

<new def="inputs" class="Lobject" />
<putfield class="LRoot" field="inputs" fieldType="LRoot" ref="estimator" value="inputs" />

Expand Down Expand Up @@ -786,6 +792,26 @@
</class>
</package>

<package name="tensorflow/distribute">
<class name="MirroredStrategy" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="self devices cross_device_ops">
<new def="x" class="Ltensorflow/distribute/run/run" />
<putfield class="LRoot" field="run" fieldType="LRoot" ref="self" value="x" />
<return value="arg0" />
</method>
</class>
</package>

<package name="tensorflow/distribute/run">
<class name="run" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="3">
<getfield class="LRoot" field="0" fieldType="LRoot" ref="arg2" def="x" />
<call class="LRoot" name="do" descriptor="()LRoot;" type="virtual" arg0="arg1" arg1="x" numArgs="2" def="v" />
<return value="v" />
</method>
</class>
</package>

<package name="tensorflow/app">
<class name="run" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="self main argv">
Expand Down
15 changes: 15 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/distribute/Strategy#example_usage_2.

import tensorflow as tf

tensor_input = tf.constant(3.0)


@tf.function
def replica_fn(input):
return input * 2.0


# Direct call.
result = replica_fn((tensor_input,)[0])
print(result)
16 changes: 16 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_callbacks2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/distribute/Strategy#example_usage_2.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
tensor_input = tf.constant(3.0)


@tf.function
def replica_fn(input):
return input * 2.0


# Indirect call to replica_fun().
result = strategy.run(replica_fn, (tensor_input,))
print(result)

0 comments on commit b96d758

Please sign in to comment.