From 5bba95f9d82b37203d8c64b7c989eac3e37d8874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Otto?= Date: Tue, 1 Oct 2024 14:29:05 +0200 Subject: [PATCH] Handle division by zero properly --- fences/core/util.py | 11 +++++++---- test/core/test_util.py | 5 +++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fences/core/util.py b/fences/core/util.py index e35b3d2..53fe25c 100644 --- a/fences/core/util.py +++ b/fences/core/util.py @@ -141,10 +141,13 @@ def total(self) -> int: def accuracy(self) -> float: total = self.total() + if total == 0: + return 0 return (self.valid_accepted + self.invalid_rejected) / total def balanced_accuracy(self) -> float: - return ( - (self.valid_accepted / (self.valid_accepted + self.valid_rejected)) + - (self.invalid_rejected / (self.invalid_accepted + self.invalid_rejected)) - ) / 2 + valid_total = self.valid_accepted + self.valid_rejected + invalid_total = self.invalid_accepted + self.invalid_rejected + if valid_total == 0 or invalid_total == 0: + return 0 + return ((self.valid_accepted / valid_total) + (self.invalid_rejected / invalid_total)) / 2 diff --git a/test/core/test_util.py b/test/core/test_util.py index e503b48..0f32186 100644 --- a/test/core/test_util.py +++ b/test/core/test_util.py @@ -114,3 +114,8 @@ def test_balanced_accuracy(self): c.invalid_accepted = 7 c.invalid_rejected = 11 self.assertAlmostEqual(c.balanced_accuracy(), ((2/5) + (11/18)) / 2) + + def test_zero(self): + c = ConfusionMatrix() + self.assertEqual(c.accuracy(), 0) + self.assertEqual(c.balanced_accuracy(), 0)