Skip to content

Commit

Permalink
[ENH] Change dist_func to four parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Tony-HYX committed Dec 15, 2023
1 parent a96cdfd commit 6610301
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions abl/reasoning/reasoner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Reasoner:
candidate, 'confidence': calculates the distance between the prediction
and each candidate based on confidence derived from the predicted probability
in the data sample. The callable function should have the signature
dist_func(data_sample, candidates) and must return a cost list. Each element
dist_func(data_sample, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
Expand Down Expand Up @@ -79,8 +79,8 @@ def _check_valid_dist(self, dist_func):
return
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 3:
raise ValueError(f"User-defined dist_func must have exactly three parameters, but got {len(params)}.")
if len(params) != 4:
raise ValueError(f"User-defined dist_func must have exactly four parameters, but got {len(params)}.")
return
else:
raise TypeError(
Expand Down Expand Up @@ -161,7 +161,8 @@ def _get_cost_list(
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
else:
cost_list = self.dist_func(data_sample, candidates, reasoning_results)
candidate_idxs = [[self.remapping[x] for x in c] for c in candidates]
cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
f"The length of the array returned by dist_func must be equal to the number of candidates. "
Expand Down

0 comments on commit 6610301

Please sign in to comment.