diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py index 273cb9019f5..84544c1c1dc 100644 --- a/torchmetrics/wrappers/tracker.py +++ b/torchmetrics/wrappers/tracker.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy from typing import Any, Dict, List, Tuple, Union @@ -152,7 +153,14 @@ def reset_all(self) -> None: def best_metric( self, return_step: bool = False - ) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]: + ) -> Union[ + None, + float, + Tuple[int, float], + Tuple[None, None], + Dict[str, Union[float, None]], + Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]], + ]: """Returns the highest metric out of all tracked. Args: @@ -163,18 +171,39 @@ def best_metric( """ if isinstance(self._base_metric, Metric): fn = torch.max if self.maximize else torch.min - idx, best = fn(self.compute_all(), 0) - if return_step: - return idx.item(), best.item() - return best.item() - else: + try: + idx, best = fn(self.compute_all(), 0) + if return_step: + return idx.item(), best.item() + return best.item() + except ValueError as error: + warnings.warn( + f"Encountered the following error when trying to get the best metric: {error}" + "this is probably due to the 'best' not being defined for this metric." + "Returning `None` instead.", + UserWarning, + ) + if return_step: + return None, None + return None + + else: # this is a metric collection res = self.compute_all() maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] idx, best = {}, {} for i, (k, v) in enumerate(res.items()): - fn = torch.max if maximize[i] else torch.min - out = fn(v, 0) - idx[k], best[k] = out[0].item(), out[1].item() + try: + fn = torch.max if maximize[i] else torch.min + out = fn(v, 0) + idx[k], best[k] = out[0].item(), out[1].item() + except ValueError as error: + warnings.warn( + f"Encountered the following error when trying to get the best metric for metric {k}:" + f"{error} this is probably due to the 'best' not being defined for this metric." + "Returning `None` instead.", + UserWarning, + ) + idx[k], best[k] = None, None if return_step: return idx, best