Skip to content

Commit

Permalink
fix MAP and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
damianoazzolini committed Oct 26, 2024
1 parent ca8b827 commit ac42523
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 17 deletions.
6 changes: 2 additions & 4 deletions pastasolver/models_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,7 @@ def get_highest_prob_and_w_id_map(
max_prob : float = 0.0
w_id_list : 'list[str]' = []

for el in current_worlds_dict:
w = current_worlds_dict[el]
for el, w in current_worlds_dict.items():
if w.model_query_count > 0 and (w.model_not_query_count == 0 if lower else True):
if w.prob == max_prob:
max_prob = w.prob
Expand Down Expand Up @@ -727,8 +726,7 @@ def get_map_solution(
else:
# group by map variables
map_worlds : 'dict[str,World]' = {}
for el in self.worlds_dict:
w = self.worlds_dict[el]
for el, w in self.worlds_dict.items():
if w.model_query_count > 0:
# keep both lower and upper
sub_w = ModelsHandler.get_sub_world(el, map_id_list)
Expand Down
3 changes: 0 additions & 3 deletions pastasolver/pasta_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,6 @@ def map_inference(self, from_string : str = "") -> 'tuple[float,list[list[str]]]
self.interface.compute_probabilities()
max_prob, map_state = self.interface.model_handler.get_map_solution(
self.parser.map_id_list, self.consider_lower_prob)
if self.interface.normalizing_factor >= 1:
max_prob = 1
# print_warning("No worlds have > 1 answer sets")

if self.normalize_prob and self.interface.normalizing_factor != 0:
max_prob = max_prob / (1 - self.interface.normalizing_factor)
Expand Down
1 change: 1 addition & 0 deletions test/test_approximate_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def wrap_test_approximate_inference(self, parameters : utils_for_tests.TestArgum
args.rejection = parameters.rejection
args.mh = parameters.mh
args.gibbs = parameters.gibbs
args.approximate_hybrid = False

lp, up = pasta_solver.approximate_solve(args)

Expand Down
11 changes: 1 addition & 10 deletions test/test_map_mpe.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import unittest

import pytest

import importlib.util

import utils_for_tests

from pastasolver.pasta_solver import Pasta

# t_utils = __import__('test_utils')




class TestClassMapMpe(unittest.TestCase):

def wrap_test_map_mpe(self,
Expand All @@ -31,9 +23,8 @@ def wrap_test_map_mpe(self,
max_p, atoms_list = pasta_solver.upper_mpe_inference()


if max_p is not None and atoms_list is not None:
if max_p > 0 and len(atoms_list) > 0:
self.assertTrue(utils_for_tests.almostEqual(max_p, expected_map_mpe, 5), test_name + ": wrong MAP/MPE")

self.assertTrue(utils_for_tests.check_if_lists_equal(atoms_list, expected_atoms_list), test_name + ": wrong atoms list")


Expand Down

0 comments on commit ac42523

Please sign in to comment.