
Confusing recommendation to use sync_dist=True even with TorchMetrics #20153

Open
srprca opened this issue Aug 2, 2024 · 9 comments
Labels
bug (Something isn't working), help wanted (Open to be worked on), logging (Related to the `LoggerConnector` and `log()`), ver: 2.2.x

Comments

srprca commented Aug 2, 2024

Bug description

Hello!

When I train and validate a model in a multi-GPU setting (HPC, an sbatch job that requests multiple GPUs on a single node), I use self.log(..., sync_dist=True) when logging PyTorch losses, and I don't specify any value for sync_dist when logging metrics from the TorchMetrics library. However, I still get warnings like

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

These specific messages correspond to logging tmc.MulticlassRecall(len(self.task.class_names), average="macro", ignore_index=self.metric_ignore_index) and individual components of tmc.MulticlassRecall(len(self.task.class_names), average="none", ignore_index=self.metric_ignore_index).

Full code listing for metric object definitions and logging is provided in the "reproducing the bug" section.

As I understand from a note here and from the discussion here, one typically doesn't need to set sync_dist explicitly when using TorchMetrics.

I wonder whether I still need to enable sync_dist=True as advised in the warnings, due to some special case I am not aware of, or whether I should follow the docs and keep things as they are. In either case, this is probably a bug, either in the documentation or in the warning code.

Thank you!

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

# assumed imports for this snippet: import torchmetrics as tm; import torchmetrics.classification as tmc
self.val_metric_funs = tm.MetricCollection(
    {
        "cm_normalize_all": tmc.MulticlassConfusionMatrix(
            len(self.task.class_names),
            ignore_index=self.metric_ignore_index,
            normalize="all",
        ),
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "recall_average_none": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_none": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_none": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
    }
)

if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(metric_fun, tmc.MulticlassConfusionMatrix):
            for true_class_num in range(metric.shape[0]):
                true_class_name = self.task.class_names[true_class_num]
                for pred_class_num in range(metric.shape[1]):
                    pred_class_name = self.task.class_names[pred_class_num]
                    self.log(
                        f"val_true_{true_class_name}_pred_{pred_class_name}_cm",
                        metric[true_class_num, pred_class_num].item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
        elif isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric.item(),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            elif metric_fun.average == "none":
                for class_num, metric_ in enumerate(metric):
                    class_name = self.task.class_names[class_num]
                    self.log(
                        f"val_{class_name}_{metric_name_}",
                        metric_.item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )

Error messages and logs

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

Environment

Current environment
#- PyTorch Lightning Version: 2.3.0
#- PyTorch Version: 2.3.1
#- Python version: 3.11.9
#- OS: Linux
#- CUDA/cuDNN version: 11.8
#- How you installed Lightning: conda-forge

More info

No response

cc @carmocca

srprca added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels Aug 2, 2024
awaelchli (Contributor) commented Aug 2, 2024

Yes, that's right: the warning shouldn't occur when logging TorchMetrics. Does it occur only with MetricCollection, or with a regular Metric too?

awaelchli added the logging (Related to the `LoggerConnector` and `log()`) and help wanted (Open to be worked on) labels and removed the needs triage (Waiting to be triaged by maintainers) label Aug 2, 2024
srprca (Author) commented Aug 2, 2024

Thank you for your reply!

I will be able to check this tomorrow, and will report back.

Meanwhile, my second suspicion is this: since I log metric.item() values, is it possible that self.log doesn't recognize these as originating from TorchMetrics and sees them as "generic" numbers or tensors?

I will try to check this hypothesis, too.

awaelchli (Contributor):

If you pass in scalar tensors, then no, of course not; in that case the warning is normal and expected. For logging TorchMetrics you would pass the metric object directly into self.log. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
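
For reference, a minimal sketch of that pattern (the model, feature size, and metric name below are placeholders, not taken from this issue):

import torch
import lightning as L
from torchmetrics.classification import MulticlassRecall


class LitModel(L.LightningModule):
    def __init__(self, num_classes: int = 5, in_features: int = 32):
        super().__init__()
        self.net = torch.nn.Linear(in_features, num_classes)  # placeholder model
        # Keep the metric as a module attribute so Lightning moves it to the right
        # device and TorchMetrics handles the cross-device accumulation itself.
        self.val_recall = MulticlassRecall(num_classes, average="macro")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        # Update the metric state, then pass the metric *object* to self.log;
        # no sync_dist is needed and the warning should not appear.
        self.val_recall(logits, y)
        self.log("val_mean_recall", self.val_recall, on_step=False, on_epoch=True)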

srprca (Author) commented Aug 2, 2024

I guess this is exactly my case. I don't exactly remember why I have .item() for clearly scalar metrics, and I can try removing .item() in their case. My use case in general, though, is that I want to log both aggregate metrics (mean recall over classes) and per-class metrics (the recall of every individual class). In the latter case I use multiclass recall with average="none" and then extract the individual elements to be saved as separate columns in my metrics CSV. Similarly, for the confusion matrix, I want to log every individual entry as a separate column in the metrics CSV file.

So I guess it's not a bug then, thank you for clarifying this!

Now I have just a couple more questions:

  1. Does this mean that the metrics are properly reduced across devices behind the scenes, and that self.log simply doesn't recognize this and shows me the warning anyway? That is, is it safe and correct to ignore it?
  2. Is there a more idiomatic way to do what I want to do here, given the use-case described above?

Thank you!

srprca closed this as completed Aug 2, 2024
srprca (Author) commented Aug 3, 2024

Actually, this still happens when I log all the metrics properly, without using .item(). Now, with the following metrics definition:

self.val_metric_funs = tm.MetricCollection(
    {
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
    }
)

and the following logging code:

metrics = self.val_metric_funs(logits, annotator_masks_mixed)
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric,
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )

I get

.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_precision', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

So here, metrics is the return value of calling the MetricCollection on my logits and true masks; I then iterate over it with for metric_name, metric in metrics.items() and finally log metric itself, not metric.item().

I will test whether this still happens without MetricCollection a bit later.
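
For reference, a short sketch of the distinction that seems to matter here; the variable names mirror the snippet above, and the second variant is the object-based pattern from the guide linked earlier (whether it fits this exact setup is an assumption):

# Calling the collection runs forward() and returns the batch values as plain
# tensors, e.g. {"recall_average_macro": tensor(0.83), ...}; by the time
# self.log sees them they are ordinary tensors, not torchmetrics Metric objects.
metrics = self.val_metric_funs(logits, annotator_masks_mixed)

# Object-based alternative: update the states and hand the collection itself to
# Lightning, which lets TorchMetrics do the cross-device synchronization.
self.val_metric_funs.update(logits, annotator_masks_mixed)
self.log_dict(self.val_metric_funs, on_step=False, on_epoch=True)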

srprca reopened this Aug 3, 2024
srprca (Author) commented Aug 3, 2024

This still happens when the metric is logged "properly" (without .item()) and without the MetricCollection wrapper.

awaelchli (Contributor)

geometrikal:

> If you pass in scalar tensors, then no, of course not; in that case the warning is normal and expected. For logging TorchMetrics you would pass the metric object directly into self.log. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html

I get the same thing when logging using the manual method.

I am going to try logging the metric object directly, but does that support ClasswiseWrapper?

david-rohrschneider:

Hello, I can confirm the confusion. I am training on just 2 GPUs and cannot find any documentation on how to use MetricCollection in distributed environments. I'm not using sync_dist, so I get the same warning, and I am not sure whether my metrics are computed and logged properly.
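
For what it's worth, a minimal sketch of the MetricCollection pattern from the TorchMetrics/Lightning integration page linked earlier in this thread (the model, class count, and metric names below are placeholders): the collection is a module attribute, its states are updated in the step, and the collection object is passed to self.log_dict, so the metric states are synced by TorchMetrics at compute time and no sync_dist flag is involved.

import torch
import lightning as L
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall


class LitModel(L.LightningModule):
    def __init__(self, num_classes: int = 5, in_features: int = 32):
        super().__init__()
        self.net = torch.nn.Linear(in_features, num_classes)  # placeholder model
        # clone() produces an independent copy; the prefix becomes part of the logged keys.
        metrics = MetricCollection(
            {
                "recall": MulticlassRecall(num_classes, average="macro"),
                "precision": MulticlassPrecision(num_classes, average="macro"),
            }
        )
        self.val_metrics = metrics.clone(prefix="val_")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        # Update the metric states, then log the collection *object*; Lightning
        # calls compute() at epoch end and the states are reduced across ranks.
        self.val_metrics.update(logits, y)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)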
