
empty_memory_format evaluator #2745

Merged
merged 1 commit into from
Jun 14, 2024

Conversation

apbose
Collaborator

@apbose apbose commented Apr 12, 2024

In this PR I am facing an issue in the following test cases:

    (
        "empty_four_dimension_memformat",
        [1, 2, 2, 1],
        torch.float32,
        "cuda",
        torch.channels_last,
    ),
    (
        "empty_five_dimension_memformat",
        [1, 2, 2, 2, 1],
        torch.float32,
        "cuda",
        torch.channels_last_3d,
    ),

In TRTInterpreter.run(), the stride of the empty_tensor
torch.ops.aten.empty.memory_format([1,2,2,1], dtype = torch.int32, memory_format = torch.channels_last)
is (4, 1, 2, 2), which is the desired torch output. Since this is an evaluator, the output can be checked at this line.
However, when it comes to this line, the stride is (4, 2, 1, 1). This is the output from the TRT engine built from the above INetwork.

Would anyone know what is going wrong?
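For reference, the two stride layouts in question can be reproduced with a small pure-Python sketch (helper names are mine, not from the PR; the formulas mirror how PyTorch lays out contiguous vs. channels_last memory for an NCHW logical shape):

```python
def contiguous_strides(shape):
    # Standard contiguous (row-major) strides: rightmost dimension is densest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

def channels_last_strides(shape):
    # channels_last: NHWC-ordered memory for an NCHW logical shape,
    # so the channel dimension is densest (stride 1).
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)

# Shape from the failing test case:
print(channels_last_strides([1, 2, 2, 1]))  # (4, 1, 2, 2) -- the desired torch output
print(contiguous_strides([1, 2, 2, 1]))     # (4, 2, 1, 1) -- the TRT engine output
```

The engine's (4, 2, 1, 1) is exactly the plain contiguous stride, which suggests the channels_last memory format is being dropped somewhere between the INetwork and the engine output.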

fixes: #2738

@apbose apbose self-assigned this Apr 12, 2024
@apbose apbose marked this pull request as draft April 12, 2024 00:22
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 12, 2024
@github-actions github-actions bot requested a review from gs-olive April 12, 2024 00:23
@apbose apbose force-pushed the empty_memory_format_evaluator branch 2 times, most recently from ed97b61 to ad5b467 Compare May 24, 2024 00:12
@apbose apbose marked this pull request as ready for review May 24, 2024 00:13
@apbose apbose requested review from zewenli98 and chohk88 May 24, 2024 00:13
@github-actions github-actions bot left a comment
There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-05-24 00:12:59.638980+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-05-24 00:14:55.849732+00:00
@@ -87,10 +87,11 @@
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return np.random.randn(*args[0])

+
def randperm_validator(randperm_node: Node) -> bool:
    dtype = randperm_node.kwargs.get("dtype", None)
    layout = randperm_node.kwargs.get("layout", None)
    input = randperm_node.args[0]
    if not isinstance(input, int):
@@ -116,10 +117,11 @@
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return np.random.permutation(args[0])
+

def empty_validator(empty_node: Node) -> bool:
    layout = empty_node.kwargs.get("layout", None)
    pin_memory = empty_node.kwargs.get("pin_memory", None)
    memory_format = empty_node.kwargs.get("memory_format", None)
@@ -162,8 +164,5 @@
    elif memory_format == torch.channels_last_3d:
        # shape of args[0] must be 5
        empty_tensor = empty_tensor.to(memory_format=torch.channels_last_3d)

    return empty_tensor
-
-
-

@apbose apbose force-pushed the empty_memory_format_evaluator branch from ad5b467 to fd330cf Compare May 24, 2024 01:50
Comment on lines 126 to 127
pin_memory = empty_node.kwargs.get("pin_memory", None)
memory_format = empty_node.kwargs.get("memory_format", None)
Collaborator
It seems pin_memory and memory_format are not used in the validator.


class TestRandConverter(DispatchTestCase):
@parameterized.expand(
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
Collaborator
Since each test has 5 arguments, do we need empty_op[4] here?
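To illustrate the reviewer's point, here is a hypothetical sketch (the empty_ops entries and field order are assumed from the test cases quoted in the description) showing that slicing only four fields silently drops the memory-format argument:

```python
# Assumed test-parameter tuples: (name, shape, dtype, device, memory_format)
empty_ops = [
    ("empty_four_dimension_memformat", [1, 2, 2, 1], "float32", "cuda", "channels_last"),
    ("empty_five_dimension_memformat", [1, 2, 2, 2, 1], "float32", "cuda", "channels_last_3d"),
]

# Expanding only four fields loses the fifth (the memory format):
four_fields = [(op[0], op[1], op[2], op[3]) for op in empty_ops]
# Expanding all five keeps it available to the test body:
all_fields = [(op[0], op[1], op[2], op[3], op[4]) for op in empty_ops]

print(len(four_fields[0]))  # 4 -- memory format dropped
print(all_fields[0][4])     # channels_last
```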


def forward(self, x):
shape_or_input[0] = x.shape[0]
return torch.empty(shape_or_input)
Collaborator
For consistency, I think it's better to use torch.ops.aten.empty.memory_format with the necessary args.

@parameterized.expand(
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
)
def test_empty(self, name, shape_or_input, data_type, device):
Collaborator
It seems data_type and device are not tested.

import torch
import torch.nn as nn
import torch_tensorrt
from harness import DispatchTestCase
Collaborator
I think it should be from .harness import DispatchTestCase

@apbose apbose force-pushed the empty_memory_format_evaluator branch from f97b696 to 52fc6a9 Compare June 5, 2024 01:02
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_empty_aten.py	2024-06-05 01:02:35.600566+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_empty_aten.py	2024-06-05 01:04:39.444401+00:00
@@ -87,11 +87,14 @@
                super().__init__()

            def forward(self, x):
                shape_or_input[0] = x.shape[0]
                return torch.ops.aten.empty.memory_format(
-                    shape_or_input, dtype=data_type, memory_format=memory_format, device=device
+                    shape_or_input,
+                    dtype=data_type,
+                    memory_format=memory_format,
+                    device=device,
                )

        empty_model = TestModule()

        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]

@apbose apbose force-pushed the empty_memory_format_evaluator branch 2 times, most recently from 1854b76 to a135bab Compare June 5, 2024 01:14
@zewenli98 zewenli98 left a comment

Just a minor comment. The rest LGTM.

Comment on lines 55 to 56
def __init__(self):
super().__init__()
Collaborator
The __init__ seems unnecessary.

@apbose apbose force-pushed the empty_memory_format_evaluator branch from 5bc1798 to efe6745 Compare June 13, 2024 17:55
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-06-13 17:55:17.730470+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-06-13 17:57:14.508152+00:00
@@ -130,12 +130,14 @@
    if layout is not None:
        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
        return False
    memory_format = empty_node.kwargs.get("memory_format", None)
    if memory_format is not None:
-        _LOGGER.debug(f"Currently we don't support specifying memory_format, got {memory_format}.")
-        return False    
+        _LOGGER.debug(
+            f"Currently we don't support specifying memory_format, got {memory_format}."
+        )
+        return False
    return True


@dynamo_tensorrt_converter(
    torch.ops.aten.empty.memory_format, capability_validator=empty_validator
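The validator pattern in the diff above can be sketched in isolation (function name and kwargs-dict signature are simplified here; the real implementation takes an FX Node and lives in ops_evaluators.py):

```python
import logging

_LOGGER = logging.getLogger(__name__)

def empty_validator_sketch(kwargs):
    # Reject nodes that request features the evaluator does not support,
    # so the op falls back to PyTorch instead of being converted.
    layout = kwargs.get("layout", None)
    if layout is not None:
        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
        return False
    memory_format = kwargs.get("memory_format", None)
    if memory_format is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying memory_format, got {memory_format}."
        )
        return False
    return True

print(empty_validator_sketch({}))                                  # True
print(empty_validator_sketch({"memory_format": "channels_last"}))  # False
```

Returning False from a capability validator keeps the op in eager PyTorch rather than attempting a conversion that would produce the wrong strides.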

@apbose apbose force-pushed the empty_memory_format_evaluator branch from efe6745 to 0180ae2 Compare June 14, 2024 21:26
@apbose apbose force-pushed the empty_memory_format_evaluator branch from 0180ae2 to 1719a13 Compare June 14, 2024 21:36
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-06-14 21:37:02.453645+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py	2024-06-14 21:38:54.121244+00:00
@@ -177,12 +177,14 @@
    if layout is not None:
        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
        return False
    memory_format = empty_node.kwargs.get("memory_format", None)
    if memory_format is not None:
-        _LOGGER.debug(f"Currently we don't support specifying memory_format, got {memory_format}.")
-        return False    
+        _LOGGER.debug(
+            f"Currently we don't support specifying memory_format, got {memory_format}."
+        )
+        return False
    return True


@dynamo_tensorrt_converter(
    torch.ops.aten.empty.memory_format, capability_validator=empty_validator

@apbose apbose force-pushed the empty_memory_format_evaluator branch from 1719a13 to 659195e Compare June 14, 2024 21:55
Review comments- adding cases for stride, correcting validator and changing call to torch.ops.aten.empty.memory_format

Removing stride and device since: 1. converting to a torch Tensor would lead to a FakeTensor; 2. get_trt_tensor in the forward pass while creating the engine does not retain the memory_format

Removing init from empty test
@apbose apbose merged commit 0387757 into main Jun 14, 2024
10 of 13 checks passed
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Successfully merging this pull request may close these issues.

aten.empty.memory_format
3 participants