Skip to content

Commit

Permalink
Pairwise pairs
Browse files Browse the repository at this point in the history
Added parallel loop for pairwise pairs
  • Loading branch information
venetiap committed Nov 1, 2024
1 parent 4cfac7c commit 49a9fbb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/plot_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def score_function(X):

print("Pairwise comparison (one vs group):", xai.pairwise(X[2], X[5:10]))

pairlist=[(X[2], X[3]), (X[2], X[4]), (X[2], X[2]), (X[4], X[2])]
print("Pairwise comparison (group of pairs):", xai.pairwise_all(pairlist))


######################################################################################
# We can also turn these into visualizations:
Expand Down
21 changes: 21 additions & 0 deletions sharp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,24 @@ def pairwise(self, sample1, sample2, **kwargs):
coalition_size=coalition_size,
**kwargs
)

def pairwise_all(self, pairs, **kwargs):
"""
set_cols_idx should be passed in kwargs if measure is marginal
pairs is a list of tuples of indexes
"""
# X_ref = self._X if self._X is not None else check_inputs(X)[0]

if "sample_size" in kwargs.keys():
sample_size = 1

influences = parallel_loop(
lambda idx: self.individual(
pairs[idx][0].reshape(1, -1), X=pairs[idx][1].reshape(1, -1), verbose=False, **kwargs
),
range(len(pairs)),
n_jobs=self.n_jobs,
progress_bar=self.verbose,
)

return np.array(influences)

0 comments on commit 49a9fbb

Please sign in to comment.