diff --git a/fences/core/util.py b/fences/core/util.py index e35b3d2..53fe25c 100644 --- a/fences/core/util.py +++ b/fences/core/util.py @@ -141,10 +141,13 @@ def total(self) -> int: def accuracy(self) -> float: total = self.total() + if total == 0: + return 0 return (self.valid_accepted + self.invalid_rejected) / total def balanced_accuracy(self) -> float: - return ( - (self.valid_accepted / (self.valid_accepted + self.valid_rejected)) + - (self.invalid_rejected / (self.invalid_accepted + self.invalid_rejected)) - ) / 2 + valid_total = self.valid_accepted + self.valid_rejected + invalid_total = self.invalid_accepted + self.invalid_rejected + if valid_total == 0 or invalid_total == 0: + return 0 + return ((self.valid_accepted / valid_total) + (self.invalid_rejected / invalid_total)) / 2 diff --git a/test/core/test_util.py b/test/core/test_util.py index e503b48..0f32186 100644 --- a/test/core/test_util.py +++ b/test/core/test_util.py @@ -114,3 +114,8 @@ def test_balanced_accuracy(self): c.invalid_accepted = 7 c.invalid_rejected = 11 self.assertAlmostEqual(c.balanced_accuracy(), ((2/5) + (11/18)) / 2) + + def test_zero(self): + c = ConfusionMatrix() + self.assertEqual(c.accuracy(), 0) + self.assertEqual(c.balanced_accuracy(), 0)