Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: TRTInterpreter output lacks return value #2114

Merged
merged 1 commit into from
Jul 21, 2023

Conversation

gs-olive
Copy link
Collaborator

@gs-olive gs-olive commented Jul 14, 2023

Description

  • Fixes error causing PyTree failures during Dynamo tracing for new converters returning lists of Tensors

The following error is encountered when testing converters:

Traceback (most recent call last):
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.10/site-packages/parameterized/parameterized.py", line 620, in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
  File "~/TensorRT/py/torch_tensorrt/dynamo/converters/test/test_split.py", line 25, in test_split
    self.run_test(
  File "~/TensorRT/py/torch_tensorrt/dynamo/converters/test_utils.py", line 102, in run_test
    super().run_test(
  File "~/TensorRT/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py", line 84, in run_test
    interpreter_result = interpreter.run(lower_precision=precision)
  File "~/TensorRT/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py", line 212, in run
    super().run()
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 155, in run
    return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.10/site-packages/torch/fx/graph.py", line 858, in process_outputs
    return self._codegen.process_outputs(out)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.10/site-packages/torch/fx/graph.py", line 620, in process_outputs
    return pytree.tree_unflatten(out, self.pytree_info.out_spec)
  File "~/python_virtual_environments/torch_trt_venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 301, in tree_unflatten
    raise ValueError(
ValueError: tree_unflatten(values, spec): `values` has length 1 but the spec refers to a pytree that holds 2 items (TreeSpec(tuple, None, [*,
  *])).

The TorchGen output verifier, invoked here, requires the outputs to be a list of values. The TRTInterpreter was not returning any values, causing the first error, then was returning tuple values, causing the second issue. The cast introduced in this PR resolves those.

Related to: #1828 (comment)

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified

@gs-olive gs-olive added component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths Story: Export/Compile Unification Issues relating to unification of Dynamo compile/export paths labels Jul 14, 2023
@gs-olive gs-olive requested a review from peri044 July 14, 2023 01:06
@gs-olive gs-olive self-assigned this Jul 14, 2023
@github-actions github-actions bot added the component: api [Python] Issues re: Python API label Jul 14, 2023
@github-actions github-actions bot requested a review from narendasan July 14, 2023 01:06
@gs-olive gs-olive force-pushed the interpreter_bugfix branch from 7f9a4a6 to e691e92 Compare July 14, 2023 21:48
@narendasan
Copy link
Collaborator

Change seems fine but when should this get merged?

@gs-olive
Copy link
Collaborator Author

Since the change is small, it can wait for #2104 to avoid major merge conflicts.

Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

- Fixes error causing PyTree failures during Dynamo tracing for new
converters returning lists of Tensors
@gs-olive gs-olive force-pushed the interpreter_bugfix branch from e691e92 to dc26e08 Compare July 21, 2023 06:01
@gs-olive gs-olive merged commit 95730fe into pytorch:main Jul 21, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths Story: Export/Compile Unification Issues relating to unification of Dynamo compile/export paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants