Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Apr 11, 2022
1 parent 80dc17e commit e817ad1
Showing 1 changed file with 38 additions and 9 deletions.
47 changes: 38 additions & 9 deletions torchmetrics/wrappers/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Tuple, Union

Expand Down Expand Up @@ -152,7 +153,14 @@ def reset_all(self) -> None:

def best_metric(
self, return_step: bool = False
) -> Union[float, Tuple[int, float], Dict[str, float], Tuple[Dict[str, int], Dict[str, float]]]:
) -> Union[
None,
float,
Tuple[int, float],
Tuple[None, None],
Dict[str, Union[float, None]],
Tuple[Dict[str, Union[int, None]], Dict[str, Union[float, None]]],
]:
"""Returns the highest metric out of all tracked.
Args:
Expand All @@ -163,18 +171,39 @@ def best_metric(
"""
if isinstance(self._base_metric, Metric):
fn = torch.max if self.maximize else torch.min
idx, best = fn(self.compute_all(), 0)
if return_step:
return idx.item(), best.item()
return best.item()
else:
try:
idx, best = fn(self.compute_all(), 0)
if return_step:
return idx.item(), best.item()
return best.item()
except ValueError as error:
warnings.warn(
f"Encountered the following error when trying to get the best metric: {error}"
"this is probably due to the 'best' not being defined for this metric."
"Returning `None` instead.",
UserWarning,
)
if return_step:
return None, None
return None

else: # this is a metric collection
res = self.compute_all()
maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize]
idx, best = {}, {}
for i, (k, v) in enumerate(res.items()):
fn = torch.max if maximize[i] else torch.min
out = fn(v, 0)
idx[k], best[k] = out[0].item(), out[1].item()
try:
fn = torch.max if maximize[i] else torch.min
out = fn(v, 0)
idx[k], best[k] = out[0].item(), out[1].item()
except ValueError as error:
warnings.warn(
f"Encountered the following error when trying to get the best metric for metric {k}:"
f"{error} this is probably due to the 'best' not being defined for this metric."
"Returning `None` instead.",
UserWarning,
)
idx[k], best[k] = None, None

if return_step:
return idx, best
Expand Down

0 comments on commit e817ad1

Please sign in to comment.