Skip to content

Commit

Permalink
fix(materialize): onnx loading with torch model available (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
ganler authored Mar 19, 2024
1 parent 4061eb7 commit cff46a5
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 26 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ jobs:
cd doc && python bug_summary.py
- name: Test core
run: |
pytest tests/core
pytest -x tests/core
- name: Test PyTorch
run: |
pip install -r requirements/sys/torch.txt --pre --upgrade
pip install -r requirements/sys/onnx.txt --pre --upgrade
pip install -r requirements/sys/tvm.txt --pre --upgrade
pip install -r requirements/sys/onnxruntime.txt --pre --upgrade
pytest tests/torch
pytest -x tests/torch
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic-cinit
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch backend.type="pt2 backend@inductor" mgen.method=concolic
Expand All @@ -42,20 +42,20 @@ jobs:
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=pt2 mgen.grad_check=true
- name: Test ONNX + ONNXRuntime
run: |
pytest tests/onnxruntime
pytest -x tests/onnxruntime
yes | python nnsmith/cli/model_gen.py model.type=onnx mgen.method=symbolic
yes | python nnsmith/cli/model_gen.py model.type=onnx backend.type=onnxruntime mgen.method=concolic
python nnsmith/cli/model_exec.py model.type=onnx backend.type=onnxruntime model.path=nnsmith_output/model.onnx
- name: Test ONNX + TVM
run: |
pytest tests/tvm
pytest -x tests/tvm
- name: Test ONNX + TRT
run: |
pytest tests/tensorrt
pytest -x tests/tensorrt
- name: Test TensorFlow
run: |
pip install -r requirements/sys/tensorflow.txt --pre --upgrade
pytest tests/tensorflow --log-cli-level=DEBUG
pytest -x tests/tensorflow --log-cli-level=DEBUG
yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=symbolic
python nnsmith/cli/model_exec.py model.type=tensorflow backend.type=xla model.path=nnsmith_output/model/
yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=concolic
Expand Down
2 changes: 1 addition & 1 deletion experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Next, we will run NNSmith to dump a bunch of random test-cases for an SUT (say P
>
> ```shell
> # PyTorch
> pip install --extra-index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
> pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
> # TensorFlow
> pip install tf-nightly
> ```
Expand Down
2 changes: 1 addition & 1 deletion nnsmith/materialize/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def load(cls, path: PathLike) -> "ONNXModel":
# FIXME: missing key(s) in state_dict: "mlist.0.data", "mlist.1.data".
if os.path.exists(torch_path):
ret.with_torch = True
ret.torch_model = cls.PTType.load(torch_path)
ret.torch_model = cls.PTType.load(torch_path).torch_model
ret.full_input_like = ret.torch_model.input_like
ret.full_output_like = ret.torch_model.output_like

Expand Down
28 changes: 13 additions & 15 deletions nnsmith/materialize/torch/parse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import operator
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, List, cast

import torch
import torch._dynamo as dynamo
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
Expand All @@ -22,27 +21,26 @@ def run_node(self, n: fx.node.Node) -> Any:


def parse(model: nn.Module, *example_args: List[torch.Tensor]) -> GraphIR:
gm: fx.GraphModule = dynamo.export(model, *example_args)[0]
gm: fx.GraphModule = fx.symbolic_trace(model)
# store shape info on nodes
sp = PropInterpreter(gm)
sp.run(*example_args)

def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
"""
Map nodes to their outputs while keeping structures and other values the same.
"""
return torch.fx.graph.map_arg(args, lambda n: n.meta["res"])

named_modules = dict(gm.named_modules())
ir = GraphIR()
name2retvals: Dict[str, List[str]] = {}
for i_node, node in enumerate(gm.graph.nodes):
for node in gm.graph.nodes:
node = cast(fx.node.Node, node)
if node.op == "placeholder":
iexpr = InstExpr(Input(dim=len(node.meta["res"].shape)), [])
input_node = Input(dim=len(node.meta["res"].shape))
input_node.abs_tensor = AbsTensor(
shape=list(node.meta["res"].shape),
dtype=DType.from_torch(node.meta["res"].dtype),
)
iexpr = InstExpr(input_node, [])
else:
args_flatten, args_treespec = pytree.tree_flatten(node.args)
kwargs_flatten, kwargs_treespec = pytree.tree_flatten(node.kwargs)
args_flatten, _ = pytree.tree_flatten(node.args)
kwargs_flatten, _ = pytree.tree_flatten(node.kwargs)
input_nodes = [
a
for a in (args_flatten + kwargs_flatten)
Expand All @@ -67,8 +65,8 @@ def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
pytree.tree_flatten(node.meta["res"])[0],
)
)
nodes2empty = (
lambda n: ConcreteOp.empty if isinstance(n, fx.node.Node) else n
nodes2empty = lambda n: (
ConcreteOp.empty if isinstance(n, fx.node.Node) else n
)
args_wo_nodes = pytree.tree_map(nodes2empty, node.args)
kwargs_wo_nodes = pytree.tree_map(nodes2empty, node.kwargs)
Expand Down
2 changes: 1 addition & 1 deletion requirements/sys/torch.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TODO(@ganler): make other platform/device distribution also work.
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--index-url https://download.pytorch.org/whl/nightly/cpu
--pre
torch
4 changes: 2 additions & 2 deletions tests/torch/test_dump_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_onnx_load_dump(tmp_path):
# check oracle
compare_two_oracle(oracle, loaded_testcase.oracle)

loaded_model = loaded_testcase.model.torch_model
loaded_model = loaded_testcase.model
loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
rerun_oracle = loaded_model.make_oracle()
compare_two_oracle(oracle, rerun_oracle)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_bug_report_load_dump(tmp_path):
# check oracle
compare_two_oracle(oracle, loaded_testcase.oracle)

loaded_model = loaded_testcase.model.torch_model
loaded_model = loaded_testcase.model
loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
rerun_oracle = loaded_model.make_oracle()
compare_two_oracle(oracle, rerun_oracle)

0 comments on commit cff46a5

Please sign in to comment.