Skip to content

Commit

Permalink
Merge pull request #291 from biocore/fix-matching
Browse files Browse the repository at this point in the history
Fixing sparse matching
  • Loading branch information
mortonjt authored Jul 8, 2021
2 parents 2ddd451 + 446a270 commit 544ef2f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
25 changes: 25 additions & 0 deletions gneiss/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,31 @@ def test_biom_match_no_common_ids(self):
with self.assertRaises(ValueError):
match(table, md)

def test_biom_match_intersect(self):
table = Table(
np.array([[0, 0, 1, 1],
[2, 3, 4, 4],
[5, 5, 3, 3]]).T,
['a', 'b', 'c', 'd'],
['s1', 's2', 'y4'])
md = pd.DataFrame([[0, 1], [1, 0], [1, 1]],
index=['s2', 's1', 's3'],
columns=['x1', 'x2'])
exp_table = Table(
np.array([[0, 0, 1, 1],
[2, 3, 4, 4]]).T,
['a', 'b', 'c', 'd'],
['s1', 's2'])
exp_md = pd.DataFrame([[1, 0], [0, 1]],
columns=['x1', 'x2'],
index=['s1', 's2'])
res_table, res_md = match(table, md)
pdt.assert_frame_equal(res_md, exp_md)
exp_df = pd.DataFrame(exp_table.to_dataframe())
res_df = pd.DataFrame(res_table.to_dataframe())
pdt.assert_frame_equal(res_df, exp_df)


def test_biom_match_tips_intersect_tips(self):
# there are less tree tips than table columns
table = Table(
Expand Down
11 changes: 9 additions & 2 deletions gneiss/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def _dense_match(table, metadata):
if len(submetadataids) != len(metadata.index):
raise ValueError("`metadata` has duplicate sample ids.")

idx = subtableids & submetadataids
idx = list(subtableids & submetadataids)
# make sure that the list is always the same to remove
# unwanted random behavior
idx.sort()
if len(idx) == 0:
raise ValueError(("No more samples left. Check to make sure that "
"the sample names between `metadata` and `table` "
Expand All @@ -148,7 +151,10 @@ def _sparse_match(table, metadata):
submetadataids = set(metadata.index)
if len(submetadataids) != len(metadata.index):
raise ValueError("`metadata` has duplicate sample ids.")
idx = subtableids & submetadataids
idx = list(subtableids & submetadataids)
# make sure that the list is always the same to remove
# unwanted random behavior
idx.sort()
if len(idx) == 0:
raise ValueError(("No more samples left. Check to make sure that "
"the sample names between `metadata` and `table` "
Expand All @@ -165,6 +171,7 @@ def sort_f(xs):
return [xs[out_metadata.index.get_loc(x)] for x in xs]

out_table = out_table.sort(sort_f=sort_f, axis='sample')
out_metadata = out_metadata.loc[out_table.ids()]
return out_table, out_metadata


Expand Down

0 comments on commit 544ef2f

Please sign in to comment.