nccl op changes
apbose committed Jan 22, 2025
1 parent e44919b commit ee4a9c8
Showing 4 changed files with 24 additions and 12 deletions.
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_gather(
+        return impl.nccl_ops.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_reduce_scatter(
+        return impl.nccl_ops.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             [args[0]],
         )
 
-    breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"
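Note on this first file: the converter entry points now dispatch to impl.nccl_ops rather than impl.distributed, the tensorrt import moves up to satisfy import ordering, and a leftover breakpoint() is removed. For orientation, a hedged, self-contained sketch of the registration pattern this module follows; the registry dict and stub bodies below are illustrative stand-ins, not the real torch_tensorrt internals:

import logging
from typing import Any, Callable, Dict

_LOGGER = logging.getLogger(__name__)
CONVERTERS: Dict[str, Callable[..., Any]] = {}

def load_tensorrt_llm() -> bool:
    # Stand-in for the real helper, which attempts to load the
    # TensorRT-LLM plugin library and reports whether it succeeded.
    return False

if load_tensorrt_llm():

    def fused_nccl_gather(ctx: Any, target: Any, args: tuple, kwargs: dict, name: str) -> Any:
        # The real body forwards to impl.nccl_ops.nccl_gather, as shown
        # in the diff above.
        raise NotImplementedError

    CONVERTERS["nccl_gather"] = fused_nccl_gather
else:
    _LOGGER.debug(
        "Did not load torch.distributed converters since TensorRT-LLM is not available"
    )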
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:

     if op_target in shape_inference_funcs:
         new_shape = shape_inference_funcs[op_target](node)
-        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
+        new_node_dtype = None
+        if node.meta["val"].dtype == torch.complex64:
+            new_node_dtype = torch.float32
+        else:
+            new_node_dtype = torch.float64
+        real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
         node.meta["val"] = fake_mode.from_tensor(real_tensor)
     else:
         print("No shape for the inference function", {op_name})
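This change rebuilds a node's fake-tensor metadata with a real dtype: complex values are carried through the lowered graph as pairs of real components, so complex64 maps to float32, and the else branch (presumably reached for complex128) maps to float64. A quick illustration of that correspondence, not code from this commit:

import torch

# A complex tensor viewed as real components gains one trailing
# dimension of size 2, with each component at half the element width.
x64 = torch.randn(4, dtype=torch.complex64)
x128 = torch.randn(4, dtype=torch.complex128)
print(torch.view_as_real(x64).dtype, torch.view_as_real(x64).shape)  # torch.float32 torch.Size([4, 2])
print(torch.view_as_real(x128).dtype)                                # torch.float64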
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
-            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                fused_node = gm.graph.create_node(
-                    op="call_function",
-                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                )
+                with gm.graph.inserting_after(wait_tensor_node):
+                    fused_node = gm.graph.create_node(
+                        op="call_function",
+                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                    )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
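The fix in fuse_distributed_ops.py: the reduce-scatter branch previously created its fused node without an insertion context, so it could land at the wrong point in the graph; both branches now insert after wait_tensor_node, and the unused fused_op = None is dropped. Reduced to a hedged, stand-alone sketch, simplified relative to the real pass (which, for example, also tracks whether the graph was modified):

import torch
from torch.fx import GraphModule

def fuse_collective(gm: GraphModule, collective_target, fused_target) -> GraphModule:
    # Find a collective whose single user is wait_tensor, then replace
    # the pair with one fused node inserted after the wait.
    for node in list(gm.graph.nodes):
        users = list(node.users)
        if (
            node.target == collective_target
            and len(users) == 1
            and users[0].target == torch.ops._c10d_functional.wait_tensor.default
        ):
            wait = users[0]
            with gm.graph.inserting_after(wait):
                fused = gm.graph.create_node(
                    op="call_function", target=fused_target, args=node.args
                )
            wait.replace_all_uses_with(fused)
            fused.meta.update(node.meta)
            gm.graph.erase_node(wait)
            gm.graph.erase_node(node)
    gm.recompile()
    return gm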
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
+
+        for i, contiguous_input in enumerate(contiguous_inputs):
+            if contiguous_input.dtype == torch.complex64:
+                contiguous_input_real = contiguous_input.real
+                contiguous_input_imag = contiguous_input.imag
+                contiguous_inputs[i] = torch.stack(
+                    (contiguous_input_real, contiguous_input_imag), dim=-1
+                )
+
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
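Finally, the runtime module unpacks complex64 inputs into stacked real/imaginary components before they reach the engine, matching the float32 metadata established earlier. The stacked layout is exactly what torch.view_as_real produces; a quick sanity check (illustrative, not code from this commit):

import torch

x = torch.randn(3, dtype=torch.complex64)
stacked = torch.stack((x.real, x.imag), dim=-1)
assert torch.equal(stacked, torch.view_as_real(x))
print(stacked.shape, stacked.dtype)  # torch.Size([3, 2]) torch.float32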
