
Warning originating in C10 backend does not get translated to Python warning if run from subprocess #75725

Open
otaj opened this issue Apr 13, 2022 · 4 comments
Labels
high priority · oncall: distributed · triage review

Comments

@otaj

otaj commented Apr 13, 2022

🐛 Describe the bug

Hi,

I want to record a warning in Python that originates in the C10 portion of the code (TORCH_WARN_ONCE) while running in a subprocess because of DDP. However, this warning seems impossible to catch because it does not propagate to Python correctly. Below is a simple demo, mostly taken from this tutorial and adapted to catch warnings.

Code and output with warnings
import contextlib
import io
import os
import sys
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()

    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)

        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 10))
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()

        except:
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)

    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)

Output:

[W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.

Caught stderr: 
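
One detail worth noting about this output: the [W reducer.cpp:1289] text shows up on the console, but not in the captured stderr, even though redirect_stderr is active. My understanding (an assumption on my side, not verified against the PyTorch sources) is that contextlib.redirect_stderr only swaps the Python-level sys.stderr object, while the default C++ warning handler writes straight to file descriptor 2, bypassing it. Below is a minimal sketch of a hypothetical capture_fd_stderr helper (not part of the demo above) that redirects fd 2 itself, which would at least make the raw text visible inside the worker:

import contextlib
import os
import sys
import tempfile


@contextlib.contextmanager
def capture_fd_stderr():
    # Capture writes to the OS-level stderr (file descriptor 2), which C++ code
    # uses directly and which contextlib.redirect_stderr cannot see, since that
    # only replaces the Python-level sys.stderr object.
    sys.stderr.flush()
    saved_fd = os.dup(2)                      # keep a copy of the original fd 2
    tmp = tempfile.TemporaryFile(mode="w+b")
    os.dup2(tmp.fileno(), 2)                  # point fd 2 at the temp file
    captured = []
    try:
        yield captured
    finally:
        sys.stderr.flush()
        os.dup2(saved_fd, 2)                  # restore the original fd 2
        os.close(saved_fd)
        tmp.seek(0)
        captured.append(tmp.read().decode())
        tmp.close()


# Hypothetical usage inside demo_basic, wrapping the DDP forward/backward:
# with capture_fd_stderr() as raw:
#     loss_fn(ddp_model(torch.randn(20, 10)), labels).backward()
# print(f"Raw fd-level stderr: {raw[0]}")

This does not solve the actual problem (the text would still arrive as a plain string rather than as a Python warning object); it only shows where the warning ends up.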

However, if I make an intentional mistake in order to raise an Exception in a similar code path (such as changing the tensor sizes so that they no longer match), the Exception is correctly propagated to Python as a RuntimeError; see the modified code below.

Code and output with Exception
import contextlib
import io
import os
import sys
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import traceback


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    new_stdout = io.StringIO()
    new_stderr = io.StringIO()

    with contextlib.ExitStack() as stack:
        warns = stack.enter_context(warnings.catch_warnings(record=True))
        stack.enter_context(contextlib.redirect_stdout(new_stdout))
        stack.enter_context(contextlib.redirect_stderr(new_stderr))
        warnings.simplefilter("always")
        warnings.warn("Simple warning", Warning)

        print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        try:
            outputs = ddp_model(torch.randn(20, 9))  # <--- change is here; this will raise an error
            labels = torch.randn(20, 5).to(rank)
            loss_fn(outputs, labels).backward()
            optimizer.step()

        except:
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            cleanup()

    print(f"Caught warnings:")
    for warn in warns:
        print(warn)

    print(f"Caught stdout: {new_stdout.getvalue()}")
    print(f"Caught stderr: {new_stderr.getvalue()}")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    run_demo(demo_basic, world_size)

Output:

Caught warnings:
{message : Warning('Simple warning'), category : 'Warning', filename : '/home/otaj/files/grid/simple-demo/main.py', lineno : 46, line : None}
Caught stdout: Running basic DDP example on rank 0.

Caught stderr: Traceback (most recent call last):
  File "/home/otaj/files/grid/simple-demo/main.py", line 60, in demo_basic
    outputs = ddp_model(torch.randn(20, 9))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/files/grid/simple-demo/main.py", line 34, in forward
    return self.net2(self.relu(self.net1(x)))
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/otaj/.pyenv/versions/pl-dev/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x9 and 10x10)

The issue was first reported on the PyTorch Slack. cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @albanD @mruberry. It is most likely linked to this issue: #72948

Thanks a lot!

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 11.2.0
Clang version: Could not collect
CMake version: version 3.23.0
Libc version: glibc-2.35

Python version: 3.9.11 (main, Apr 7 2022, 15:33:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.17.1-zen1-1-zen-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.112
GPU models and configuration: GPU 0: NVIDIA T1200 Laptop GPU
Nvidia driver version: 510.60.02
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.3.3
/usr/lib/libcudnn_adv_infer.so.8.3.3
/usr/lib/libcudnn_adv_train.so.8.3.3
/usr/lib/libcudnn_cnn_infer.so.8.3.3
/usr/lib/libcudnn_cnn_train.so.8.3.3
/usr/lib/libcudnn_ops_infer.so.8.3.3
/usr/lib/libcudnn_ops_train.so.8.3.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.942
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.3
[pip3] pytorch-lightning==1.7.0.dev0
[pip3] torch==1.11.0+cu113
[pip3] torchmetrics==0.7.3
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect

@albanD added the oncall: distributed label on Apr 13, 2022
@albanD
Collaborator

albanD commented Apr 18, 2022

Note that beyond the "the warning is not properly propagated back to the main process" theory, which is very likely to explain this, there is also the possibility that the warning is being triggered via a pybind11 binding that is missing the proper HANDLE_TH_ERRORS macro, which would prevent the warning translation from happening.

The way to differentiate between the two is that in the first case the warning happens on the Python side but is not caught (first case above), while in the other case it happens in C++ directly, only writing the raw string to stderr.

From checking the logs, the second seems to be what happens here, by the way. So it is most likely a pybind11 binding that is missing the macros.
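
To tell the two cases apart in practice, one option (a rough sketch with a hypothetical hook name, not an existing PyTorch facility) is to override warnings.showwarning in the spawned worker, outside of any catch_warnings(record=True) block so it does not get replaced:

import warnings

def _logging_showwarning(message, category, filename, lineno, file=None, line=None):
    # Fires only for warnings that actually go through Python's warnings machinery.
    print(f"[python warning hook] {category.__name__}: {message} ({filename}:{lineno})")

warnings.showwarning = _logging_showwarning
warnings.simplefilter("always")

If the hook fires for the DDP warning, the problem is only that the warning is not propagated back to the main process; if it never fires while the raw "[W reducer.cpp:...]" string still lands on stderr, the C++ to Python translation never happened, which points at the missing-macro case.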

@otaj
Author

otaj commented Jun 22, 2022

Hi @albanD, can I ask if this is moving forward somehow? Thanks a lot!

@albanD
Collaborator

albanD commented Jun 22, 2022

I don't know of any update, I'm afraid.
cc @mrshenli from the distributed team

@otaj
Author

otaj commented Sep 27, 2022

Hi, @albanD, @mrshenli, do you have any updates? Thanks a lot!
