Skip to content

Commit

Permalink
[ENH] Add reasoning results in function return
Browse files Browse the repository at this point in the history
  • Loading branch information
Tony-HYX committed Dec 14, 2023
1 parent c8f537a commit 554c20c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 70 deletions.
102 changes: 58 additions & 44 deletions abl/reasoning/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ def abduce_candidates(
Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
knowledge base.
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision)

Expand Down Expand Up @@ -173,19 +175,22 @@ def revise_at_idx(
Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
knowledge base.
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
candidates = []
candidates, reasoning_results = [], []
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
for c in abduce_c:
candidate = pseudo_label.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
if self._check_equal(self.logic_forward(candidate, *(x,) if self._num_args == 2 else ()), y):
candidates.append(candidate)
return candidates
reasoning_result = self.logic_forward(candidate, *(x,) if self._num_args == 2 else ())
if self._check_equal(reasoning_result, y):
candidates.append(candidate); reasoning_results.append(reasoning_result)
return candidates, reasoning_results

def _revision(
self,
Expand All @@ -198,13 +203,12 @@ def _revision(
For a specified number of labels in a pseudo label sample to revise, iterate through
all possible indices to find any candidates that are compatible with the knowledge base.
"""
new_candidates = []
new_candidates, new_reasoning_results = [], []
revision_idx_list = combinations(range(len(pseudo_label)), revision_num)

for revision_idx in revision_idx_list:
candidates = self.revise_at_idx(pseudo_label, y, x, revision_idx)
new_candidates.extend(candidates)
return new_candidates
candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx)
new_candidates.extend(candidates); new_reasoning_results.extend(reasoning_results)
return new_candidates, new_reasoning_results

@abl_cache()
def _abduce_by_search(
Expand Down Expand Up @@ -237,26 +241,30 @@ def _abduce_by_search(
Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
knowledge base.
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
candidates = []
candidates, reasoning_results = [], []
for revision_num in range(len(pseudo_label) + 1):
candidates.extend(self._revision(revision_num, pseudo_label, y, x))
new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results)
if len(candidates) > 0:
min_revision_num = revision_num
break
if revision_num >= max_revision_num:
return []
return [], []

for revision_num in range(
min_revision_num + 1, min_revision_num + require_more_revision + 1
):
if revision_num > max_revision_num:
return candidates
candidates.extend(self._revision(revision_num, pseudo_label, y, x))
return candidates
return candidates, reasoning_results
new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results)
return candidates, reasoning_results

def __repr__(self):
return (
Expand Down Expand Up @@ -363,28 +371,31 @@ def abduce_candidates(
Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
knowledge base.
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list:
return []
return [], []

all_candidates = self._find_candidate_GKB(pseudo_label, y)
all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y)
if len(all_candidates) == 0:
return []
return [], []

cost_list = hamming_dist(pseudo_label, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
reasoning_results = [all_reasoning_results[idx] for idx in idxs]
return candidates, reasoning_results

def _find_candidate_GKB(self, pseudo_label: List[Any], y: Any) -> List[List[Any]]:
"""
Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results,
return all candidates whose reasoning results fall within the
return all candidates and their corresponding reasoning results which fall within the
[y - max_err, y + max_err] range.
"""
if isinstance(y, (int, float)):
Expand All @@ -394,15 +405,14 @@ def _find_candidate_GKB(self, pseudo_label: List[Any], y: Any) -> List[List[Any]
low_key = bisect.bisect_left(key_list, y - self.max_err)
high_key = bisect.bisect_right(key_list, y + self.max_err)

all_candidates = [
candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]
]
return all_candidates

all_candidates, all_reasoning_results = [], []
for key in key_list[low_key:high_key]:
for candidate in potential_candidates[key]:
all_candidates.append(candidate); all_reasoning_results.append(key)
else:
return self.GKB[len(pseudo_label)][y]
all_candidates = self.GKB[len(pseudo_label)][y]
all_reasoning_results = [y] * len(all_candidates)
return all_candidates, all_reasoning_results

def __repr__(self):
GKB_info_parts = []
Expand Down Expand Up @@ -551,11 +561,15 @@ def revise_at_idx(
Returns
-------
List[List[Any]]
Tuple[List[List[Any]], List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
knowledge base.
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
candidates = []
candidates, reasoning_results = [], []
query_string = self.get_query_string(pseudo_label, y, x, revision_idx)
save_pseudo_label = pseudo_label
pseudo_label = flatten(pseudo_label)
Expand All @@ -565,8 +579,8 @@ def revise_at_idx(
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_list(candidate, save_pseudo_label)
candidates.append(candidate)
return candidates
candidates.append(candidate); reasoning_results.append(y)
return candidates, reasoning_results

def __repr__(self):
return (
Expand Down
39 changes: 24 additions & 15 deletions abl/reasoning/reasoner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def _check_valid_dist(self, dist_func):
return
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 2:
raise ValueError(f"User-defined dist_func must have exactly two parameters, but got {len(params)}.")
if len(params) != 3:
raise ValueError(f"User-defined dist_func must have exactly three parameters, but got {len(params)}.")
return
else:
raise TypeError(
Expand All @@ -102,6 +102,7 @@ def _get_one_candidate(
self,
data_sample: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> List[Any]:
"""
Due to the nondeterminism of abductive reasoning, there could be multiple candidates
Expand All @@ -114,6 +115,8 @@ def _get_one_candidate(
Data sample.
candidates : List[List[Any]]
Multiple compatible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.
Returns
-------
Expand All @@ -125,14 +128,15 @@ def _get_one_candidate(
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(data_sample, candidates)
cost_array = self._get_cost_list(data_sample, candidates, reasoning_results)
candidate = candidates[np.argmin(cost_array)]
return candidate

def _get_cost_list(
self,
data_sample: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> np.ndarray:
"""
Get the list of costs between each candidate and the given data sample.
Expand All @@ -143,6 +147,8 @@ def _get_cost_list(
Data sample.
candidates : List[List[Any]]
Multiple compatible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.
Returns
-------
Expand All @@ -155,7 +161,7 @@ def _get_cost_list(
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
else:
cost_list = self.dist_func(data_sample, candidates)
cost_list = self.dist_func(data_sample, candidates, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
f"The length of the array returned by dist_func must be equal to the number of candidates. "
Expand Down Expand Up @@ -222,11 +228,11 @@ def zoopt_revision_score(
The revision score for the solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.kb.revise_at_idx(
candidates, reasoning_results = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates))
return np.min(self._get_cost_list(data_sample, candidates, reasoning_results))
else:
return symbol_num

Expand Down Expand Up @@ -281,19 +287,22 @@ def abduce(self, data_sample: ListData) -> List[Any]:
if self.use_zoopt:
solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
candidates, reasoning_results = self.kb.revise_at_idx(
pseudo_label=data_sample.pred_pseudo_label,
y=data_sample.Y,
x=data_sample.X,
revision_idx=revision_idx
)
else:
candidates = self.kb.abduce_candidates(
data_sample.pred_pseudo_label,
data_sample.Y,
data_sample.X,
max_revision_num,
self.require_more_revision,
candidates, reasoning_results = self.kb.abduce_candidates(
pseudo_label=data_sample.pred_pseudo_label,
y=data_sample.Y,
x=data_sample.X,
max_revision_num=max_revision_num,
require_more_revision=self.require_more_revision
)

candidate = self._get_one_candidate(data_sample, candidates)
candidate = self._get_one_candidate(data_sample, candidates, reasoning_results)
return candidate

def batch_abduce(self, data_samples: ListData) -> List[List[Any]]:
Expand Down
22 changes: 11 additions & 11 deletions tests/test_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def test_logic_forward(self, kb_add):

def test_revise_at_idx(self, kb_add):
result = kb_add.revise_at_idx([0, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
assert result == [[0, 2]]
assert result == ([[0, 2]], [2])
result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
assert result == []
assert result == ([], [])
result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0, 1])
assert result == [[0, 2], [1, 1], [2, 0]]
assert result == ([[0, 2], [1, 1], [2, 0]], [2, 2, 2])

def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
assert result == [[0, 1]]
assert result == ([[0, 1]], [1])
result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
assert result == [[1, 0]]
assert result == ([[1, 0]], [1])


class TestGroundKB(object):
Expand All @@ -47,7 +47,7 @@ def test_abduce_candidates_ground(self, kb_add_ground):
result = kb_add_ground.abduce_candidates(
[1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
)
assert result == [(1, 0)]
assert result == ([(1, 0)], [1])


class TestPrologKB(object):
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_logic_forward_pl2(self, kb_hed):

def test_revise_at_idx(self, kb_add_prolog):
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0])
assert result == [[0, 2]]
assert result == ([[0, 2]], [2])


class TestReaonser(object):
Expand All @@ -101,8 +101,8 @@ def test_invalid_predefined_dist_func(self, kb_add):
excinfo.value
)

def random_dist(self, data_sample, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
def random_dist(self, data_sample, candidates, reasoning_results):
cost_list = [np.random.rand() for _ in candidates]
return cost_list

def test_user_defined_dist_func(self, kb_add):
Expand All @@ -113,14 +113,14 @@ def invalid_dist1(self, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list

def invalid_dist2(self, data_sample, candidates):
def invalid_dist2(self, data_sample, candidates, reasoning_results):
cost_list = np.array([np.random.rand() for _ in candidates])
return np.append(cost_list, np.random.rand())

def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add):
with pytest.raises(ValueError) as excinfo:
Reasoner(kb_add, self.invalid_dist1)
assert 'User-defined dist_func must have exactly two parameters' in str(
assert 'User-defined dist_func must have exactly three parameters' in str(
excinfo.value
)
with pytest.raises(ValueError) as excinfo:
Expand Down

0 comments on commit 554c20c

Please sign in to comment.