diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index c80b242cc1..60c65ff560 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -733,9 +733,9 @@ def gather(args): else: raise # re-raise other errors, if no picklist. - save_prefetch.add_many(counter.siglist) + save_prefetch.add_many(counter.signatures()) # subtract found hashes as we can. - for found_sig in counter.siglist: + for found_sig in counter.signatures(): noident_mh.remove_many(found_sig.minhash) # optionally calculate and save prefetch csv @@ -935,7 +935,7 @@ def multigather(args): counters = [] for db in databases: counter = db.counter_gather(prefetch_query, args.threshold_bp) - for found_sig in counter.siglist: + for found_sig in counter.signatures(): noident_mh.remove_many(found_sig.minhash) counters.append(counter) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index eb8a55a94c..c55fcf1a35 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -39,11 +39,15 @@ from abc import abstractmethod, ABC from collections import namedtuple, Counter -from sourmash.search import (make_jaccard_search_query, make_gather_query, +from sourmash.search import (make_jaccard_search_query, + make_containment_query, calc_threshold_from_bp) from sourmash.manifest import CollectionManifest from sourmash.logging import debug_literal from sourmash.signature import load_signatures, save_signatures +from sourmash.minhash import (flatten_and_downsample_scaled, + flatten_and_downsample_num, + flatten_and_intersect_scaled) # generic return tuple for Index.search and Index.gather IndexSearchResult = namedtuple('Result', 'score, signature, location') @@ -108,7 +112,7 @@ def find(self, search_fn, query, **kwargs): search_fn follows the protocol in JaccardSearch objects. - Returns a list. + Generator. Returns 0 or more IndexSearchResult objects. """ # first: is this query compatible with this search? 
search_fn.check_is_compatible(query) @@ -124,50 +128,19 @@ def find(self, search_fn, query, **kwargs): query_scaled = query_mh.scaled def prepare_subject(subj_mh): - assert subj_mh.scaled - if subj_mh.track_abundance: - subj_mh = subj_mh.flatten() - - # downsample subject to highest scaled - subj_scaled = subj_mh.scaled - if subj_scaled < query_scaled: - return subj_mh.downsample(scaled=query_scaled) - else: - return subj_mh + return flatten_and_downsample_scaled(subj_mh, query_scaled) def prepare_query(query_mh, subj_mh): - assert subj_mh.scaled - - # downsample query to highest scaled - subj_scaled = subj_mh.scaled - if subj_scaled > query_scaled: - return query_mh.downsample(scaled=subj_scaled) - else: - return query_mh + return flatten_and_downsample_scaled(query_mh, subj_mh.scaled) else: # num query_num = query_mh.num def prepare_subject(subj_mh): - assert subj_mh.num - if subj_mh.track_abundance: - subj_mh = subj_mh.flatten() - - # downsample subject to smallest num - subj_num = subj_mh.num - if subj_num > query_num: - return subj_mh.downsample(num=query_num) - else: - return subj_mh + return flatten_and_downsample_num(subj_mh, query_num) def prepare_query(query_mh, subj_mh): - assert subj_mh.num - # downsample query to smallest num - subj_num = subj_mh.num - if subj_num < query_num: - return query_mh.downsample(num=subj_num) - else: - return query_mh + return flatten_and_downsample_num(query_mh, subj_mh.num) # now, do the search! for subj, location in self.signatures_with_location(): @@ -195,7 +168,7 @@ def prepare_query(query_mh, subj_mh): yield IndexSearchResult(score, subj, location) def search_abund(self, query, *, threshold=None, **kwargs): - """Return set of matches with angular similarity above 'threshold'. + """Return list of IndexSearchResult with angular similarity above 'threshold'. Results will be sorted by similarity, highest to lowest. 
""" @@ -223,7 +196,7 @@ def search_abund(self, query, *, threshold=None, **kwargs): def search(self, query, *, threshold=None, do_containment=False, do_max_containment=False, best_only=False, **kwargs): - """Return set of matches with similarity above 'threshold'. + """Return list of IndexSearchResult with similarity above 'threshold'. Results will be sorted by similarity, highest to lowest. @@ -239,50 +212,55 @@ def search(self, query, *, threshold=None, threshold = float(threshold) search_obj = make_jaccard_search_query(do_containment=do_containment, - do_max_containment=do_max_containment, + do_max_containment=do_max_containment, best_only=best_only, threshold=threshold) # do the actual search: - matches = [] - - for sr in self.find(search_obj, query, **kwargs): - matches.append(sr) + matches = list(self.find(search_obj, query, **kwargs)) # sort! matches.sort(key=lambda x: -x.score) return matches def prefetch(self, query, threshold_bp, **kwargs): - "Return all matches with minimum overlap." + """Return all matches with minimum overlap. + + Generator. Returns 0 or more IndexSearchResult namedtuples. + """ if not self: # empty database? quit. raise ValueError("no signatures to search") - search_fn = make_gather_query(query.minhash, threshold_bp, - best_only=False) + # default best_only to False + best_only = kwargs.get('best_only', False) + + search_fn = make_containment_query(query.minhash, threshold_bp, + best_only=best_only) for sr in self.find(search_fn, query, **kwargs): yield sr - def gather(self, query, threshold_bp=None, **kwargs): - "Return the match with the best Jaccard containment in the Index." + def best_containment(self, query, threshold_bp=None, **kwargs): + """Return the match with the best Jaccard containment in the Index. - results = [] - for result in self.prefetch(query, threshold_bp, **kwargs): - results.append(result) + Returns an IndexSearchResult namedtuple or None. + """ - # sort results by best score. 
- results.sort(reverse=True, - key=lambda x: (x.score, x.signature.md5sum())) + results = self.prefetch(query, threshold_bp, best_only=True, **kwargs) + results = sorted(results, + key=lambda x: (-x.score, x.signature.md5sum())) - return results[:1] + try: + return next(iter(results)) + except StopIteration: + return None def peek(self, query_mh, *, threshold_bp=0): """Mimic CounterGather.peek() on top of Index. This is implemented for situations where we don't want to use 'prefetch' functionality. It is a light wrapper around the - 'gather'/search-by-containment method. + 'best_containment(...)' method. """ from sourmash import SourmashSignature @@ -291,7 +269,7 @@ def peek(self, query_mh, *, threshold_bp=0): # run query! try: - result = self.gather(query_ss, threshold_bp=threshold_bp) + result = self.best_containment(query_ss, threshold_bp=threshold_bp) except ValueError: result = None @@ -299,14 +277,10 @@ def peek(self, query_mh, *, threshold_bp=0): return [] # if matches, calculate intersection & return. - sr = result[0] - match_mh = sr.signature.minhash - scaled = max(query_mh.scaled, match_mh.scaled) - match_mh = match_mh.downsample(scaled=scaled).flatten() - query_mh = query_mh.downsample(scaled=scaled) - intersect_mh = match_mh & query_mh + intersect_mh = flatten_and_intersect_scaled(result.signature.minhash, + query_mh) - return [sr, intersect_mh] + return [result, intersect_mh] def consume(self, intersect_mh): "Mimic CounterGather.consume on top of Index. Yes, this is backwards." @@ -326,7 +300,7 @@ def counter_gather(self, query, threshold_bp, **kwargs): prefetch_query.minhash = prefetch_query.minhash.flatten() # find all matches and construct a CounterGather object. 
- counter = CounterGather(prefetch_query.minhash) + counter = CounterGather(prefetch_query) for result in self.prefetch(prefetch_query, threshold_bp, **kwargs): counter.add(result.signature, location=result.location) @@ -721,9 +695,14 @@ class CounterGather: This particular implementation maintains a collections.Counter that is used to quickly find the best match when 'peek' is called, but other implementations are possible ;). + + Note that redundant matches (SourmashSignature objects) with + duplicate md5s are collapsed inside the class, because we use the + md5sum as a key into the dictionary used to store matches. """ - def __init__(self, query_mh): - "Constructor - takes a query FracMinHash." + def __init__(self, query): + "Constructor - takes a query SourmashSignature." + query_mh = query.minhash if not query_mh.scaled: raise ValueError('gather requires scaled signatures') @@ -732,8 +711,8 @@ def __init__(self, query_mh): self.scaled = query_mh.scaled # use these to track loaded matches & their locations - self.siglist = [] - self.locations = [] + self.siglist = {} + self.locations = {} # ...and also track overlaps with the progressive query self.counter = Counter() @@ -749,11 +728,11 @@ def add(self, ss, *, location=None, require_overlap=True): # upon insertion, count & track overlap with the specific query. overlap = self.orig_query_mh.count_common(ss.minhash, True) if overlap: - i = len(self.siglist) + md5 = ss.md5sum() - self.counter[i] = overlap - self.siglist.append(ss) - self.locations.append(location) + self.counter[md5] = overlap + self.siglist[md5] = ss + self.locations[md5] = location # note: scaled will be max of all matches. self.downsample(ss.minhash.scaled) @@ -766,6 +745,11 @@ def downsample(self, scaled): self.scaled = scaled return self.scaled + def signatures(self): + "Return all signatures." 
+ for ss in self.siglist.values(): + yield ss + def peek(self, cur_query_mh, *, threshold_bp=0): "Get next 'gather' result for this database, w/o changing counters." self.query_started = 1 @@ -789,11 +773,11 @@ def peek(self, cur_query_mh, *, threshold_bp=0): raise ValueError("current query not a subset of original query") # are we setting a threshold? - threshold, n_threshold_hashes = calc_threshold_from_bp(threshold_bp, - scaled, - len(cur_query_mh)) - # is it too high to ever match? if so, exit. - if threshold > 1.0: + try: + x = calc_threshold_from_bp(threshold_bp, scaled, len(cur_query_mh)) + threshold, n_threshold_hashes = x + except ValueError: + # too high to ever match => exit return [] # Find the best match using the internal Counter. diff --git a/src/sourmash/minhash.py b/src/sourmash/minhash.py index f4411b30d7..97b9973aef 100644 --- a/src/sourmash/minhash.py +++ b/src/sourmash/minhash.py @@ -103,6 +103,39 @@ def translate_codon(codon): raise ValueError(e.message) +def flatten_and_downsample_scaled(mh, *scaled_vals): + "Flatten MinHash object and downsample to max of scaled values." + assert mh.scaled + assert all( (x > 0 for x in scaled_vals) ) + + mh = mh.flatten() + scaled = max(scaled_vals) + if scaled > mh.scaled: + return mh.downsample(scaled=scaled) + return mh + + +def flatten_and_downsample_num(mh, *num_vals): + "Flatten MinHash object and downsample to min of num values." + assert mh.num + assert all( (x > 0 for x in num_vals) ) + + mh = mh.flatten() + num = min(num_vals) + if num < mh.num: + return mh.downsample(num=num) + return mh + + +def flatten_and_intersect_scaled(mh1, mh2): + "Flatten and downsample two scaled MinHash objs, then return intersection." + scaled = max(mh1.scaled, mh2.scaled) + mh1 = mh1.flatten().downsample(scaled=scaled) + mh2 = mh2.flatten().downsample(scaled=scaled) + + return mh1 & mh2 + + class _HashesWrapper(Mapping): "A read-only view of the hashes contained by a MinHash object." 
def __init__(self, h): diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 3e03951978..96a218b7d0 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -20,6 +20,9 @@ def calc_threshold_from_bp(threshold_bp, scaled, query_size): n_threshold_hashes = 0 if threshold_bp: + if threshold_bp < 0: + raise TypeError("threshold_bp must be non-negative") + # if we have a threshold_bp of N, then that amounts to N/scaled # hashes: n_threshold_hashes = float(threshold_bp) / scaled @@ -27,6 +30,9 @@ def calc_threshold_from_bp(threshold_bp, scaled, query_size): # that then requires the following containment: threshold = n_threshold_hashes / query_size + # is it too high to ever match? + if threshold > 1.0: + raise ValueError("requested threshold_bp is unattainable with this query") return threshold, n_threshold_hashes @@ -62,8 +68,8 @@ def make_jaccard_search_query(*, return search_obj -def make_gather_query(query_mh, threshold_bp, *, best_only=True): - "Make a search object for gather." +def make_containment_query(query_mh, threshold_bp, *, best_only=True): + "Make a search object for containment, with threshold_bp." if not query_mh: raise ValueError("query is empty!?") @@ -72,21 +78,7 @@ def make_gather_query(query_mh, threshold_bp, *, best_only=True): raise TypeError("query signature must be calculated with scaled") # are we setting a threshold? - threshold = 0 - if threshold_bp: - if threshold_bp < 0: - raise TypeError("threshold_bp must be non-negative") - - # if we have a threshold_bp of N, then that amounts to N/scaled - # hashes: - n_threshold_hashes = threshold_bp / scaled - - # that then requires the following containment: - threshold = n_threshold_hashes / len(query_mh) - - # is it too high to ever match? if so, exit. 
- if threshold > 1.0: - raise ValueError("requested threshold_bp is unattainable with this query") + threshold, _ = calc_threshold_from_bp(threshold_bp, scaled, len(query_mh)) if best_only: search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT, diff --git a/tests/test_index.py b/tests/test_index.py index e36275b092..ad04598db1 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -168,11 +168,11 @@ def test_linear_index_gather_subj_has_abundance(): linear = LinearIndex() linear.insert(ss) - results = list(linear.gather(qs, threshold=0)) - assert len(results) == 1 + result = linear.best_containment(qs, threshold=0) + assert result # note: gather returns _original_ signature, not flattened - assert results[0].signature == ss + assert result.signature == ss def test_index_search_subj_scaled_is_lower(): @@ -457,22 +457,24 @@ def test_linear_gather_threshold_1(): # query with empty hashes assert not new_mh with pytest.raises(ValueError): - linear.gather(SourmashSignature(new_mh)) + linear.best_containment(SourmashSignature(new_mh)) # add one hash new_mh.add_hash(mins.pop()) assert len(new_mh) == 1 - results = linear.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = linear.best_containment(SourmashSignature(new_mh)) + assert result + + # it's a namedtuple, so we can unpack like a tuple. + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None # check with a threshold -> should be no results. 
with pytest.raises(ValueError): - linear.gather(SourmashSignature(new_mh), threshold_bp=5000) + linear.best_containment(SourmashSignature(new_mh), threshold_bp=5000) # add three more hashes => length of 4 new_mh.add_hash(mins.pop()) @@ -480,16 +482,16 @@ def test_linear_gather_threshold_1(): new_mh.add_hash(mins.pop()) assert len(new_mh) == 4 - results = linear.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = linear.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None # check with a too-high threshold -> should be no results. with pytest.raises(ValueError): - linear.gather(SourmashSignature(new_mh), threshold_bp=5000) + linear.best_containment(SourmashSignature(new_mh), threshold_bp=5000) def test_linear_gather_threshold_5(): @@ -519,17 +521,18 @@ def test_linear_gather_threshold_5(): new_mh.add_hash(mins.pop()) # should get a result with no threshold (any match at all is returned) - results = linear.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = linear.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name == 'foo' # now, check with a threshold_bp that should be meet-able. 
- results = linear.gather(SourmashSignature(new_mh), threshold_bp=5000) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = linear.best_containment(SourmashSignature(new_mh), + threshold_bp=5000) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name == 'foo' @@ -1097,7 +1100,7 @@ def test_multi_index_search(): def test_multi_index_gather(): - # test MultiIndex.gather + # test MultiIndex.best_containment sig2 = utils.get_test_data('2.fa.sig') sig47 = utils.get_test_data('47.fa.sig') sig63 = utils.get_test_data('63.fa.sig') @@ -1115,16 +1118,16 @@ def test_multi_index_gather(): None) lidx = lidx.select(ksize=31) - matches = lidx.gather(ss2) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][2] == 'A' + match = lidx.best_containment(ss2) + assert match + assert match.score == 1.0 + assert match.location == 'A' - matches = lidx.gather(ss47) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][1] == ss47 - assert matches[0][2] == sig47 # no source override + match = lidx.best_containment(ss47) + assert match + assert match.score == 1.0 + assert match.signature == ss47 + assert match.location == sig47 # no source override def test_multi_index_signatures(): @@ -1562,7 +1565,7 @@ def test_counter_gather_test_consume(): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = CounterGather(query_ss.minhash) + counter = CounterGather(query_ss) counter.add(match_ss_1, location='loc a') counter.add(match_ss_2, location='loc b') counter.add(match_ss_3, location='loc c') @@ -1570,12 +1573,16 @@ def test_counter_gather_test_consume(): ### ok, dig into actual counts... 
import pprint pprint.pprint(counter.counter) - pprint.pprint(counter.siglist) + pprint.pprint(list(counter.signatures())) pprint.pprint(counter.locations) - assert counter.siglist == [ match_ss_1, match_ss_2, match_ss_3 ] - assert counter.locations == ['loc a', 'loc b', 'loc c'] - assert list(counter.counter.items()) == [(0, 10), (1, 8), (2, 4)] + assert set(counter.signatures()) == set([match_ss_1, match_ss_2, match_ss_3]) + assert list(sorted(counter.locations.values())) == ['loc a', 'loc b', 'loc c'] + pprint.pprint(counter.counter.most_common()) + assert list(counter.counter.most_common()) == \ + [('26d4943627b33c446f37be1f5baf8d46', 10), + ('f51cedec90ea666e0ebc11aa274eca61', 8), + ('f331f8279113d77e42ab8efca8f9cc17', 4)] ## round 1 @@ -1586,9 +1593,12 @@ def test_counter_gather_test_consume(): assert cur_query == query_ss.minhash counter.consume(intersect_mh) - assert counter.siglist == [ match_ss_1, match_ss_2, match_ss_3 ] - assert counter.locations == ['loc a', 'loc b', 'loc c'] - assert list(counter.counter.items()) == [(1, 5), (2, 4)] + assert set(counter.signatures()) == set([ match_ss_1, match_ss_2, match_ss_3 ]) + assert list(sorted(counter.locations.values())) == ['loc a', 'loc b', 'loc c'] + pprint.pprint(counter.counter.most_common()) + assert list(counter.counter.most_common()) == \ + [('f51cedec90ea666e0ebc11aa274eca61', 5), + ('f331f8279113d77e42ab8efca8f9cc17', 4)] ### round 2 @@ -1599,9 +1609,12 @@ def test_counter_gather_test_consume(): assert cur_query != query_ss.minhash counter.consume(intersect_mh) - assert counter.siglist == [ match_ss_1, match_ss_2, match_ss_3 ] - assert counter.locations == ['loc a', 'loc b', 'loc c'] - assert list(counter.counter.items()) == [(2, 2)] + assert set(counter.signatures()) == set([ match_ss_1, match_ss_2, match_ss_3 ]) + assert list(sorted(counter.locations.values())) == ['loc a', 'loc b', 'loc c'] + + pprint.pprint(counter.counter.most_common()) + assert list(counter.counter.most_common()) == \ + 
[('f331f8279113d77e42ab8efca8f9cc17', 2)] ## round 3 @@ -1612,9 +1625,10 @@ def test_counter_gather_test_consume(): assert cur_query != query_ss.minhash counter.consume(intersect_mh) - assert counter.siglist == [ match_ss_1, match_ss_2, match_ss_3 ] - assert counter.locations == ['loc a', 'loc b', 'loc c'] - assert list(counter.counter.items()) == [] + assert set(counter.signatures()) == set([ match_ss_1, match_ss_2, match_ss_3 ]) + assert list(sorted(counter.locations.values())) == ['loc a', 'loc b', 'loc c'] + pprint.pprint(counter.counter.most_common()) + assert list(counter.counter.most_common()) == [] ## round 4 - nothing left! @@ -1623,9 +1637,41 @@ def test_counter_gather_test_consume(): assert not results counter.consume(intersect_mh) - assert counter.siglist == [ match_ss_1, match_ss_2, match_ss_3 ] - assert counter.locations == ['loc a', 'loc b', 'loc c'] - assert list(counter.counter.items()) == [] + assert set(counter.signatures()) == set([ match_ss_1, match_ss_2, match_ss_3 ]) + assert list(sorted(counter.locations.values())) == ['loc a', 'loc b', 'loc c'] + assert list(counter.counter.most_common()) == [] + + +def test_counter_gather_identical_md5sum(): + # open-box testing of CounterGather.consume(...) 
+ # check what happens with identical matches w/different names + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + # same as match_mh_1 + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(0, 10)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + # identical md5sum + assert match_ss_1.md5sum() == match_ss_2.md5sum() + + # load up the counter + counter = CounterGather(query_ss) + counter.add(match_ss_1, location='loc a') + counter.add(match_ss_2, location='loc b') + + assert len(counter.siglist) == 1 + stored_match = list(counter.siglist.values()).pop() + assert stored_match.name == 'match2' + # CTB note: this behavior may be changed freely, as the protocol + # tests simply specify that _one_ of the identical matches is + # returned. See test_counter_gather_multiple_identical_matches. def test_lazy_index_1(): @@ -1773,7 +1819,7 @@ def test_revindex_index_search(): def test_revindex_gather(): - # check that RevIndex.gather works. + # check that RevIndex.best_containment works. 
sig2 = utils.get_test_data("2.fa.sig") sig47 = utils.get_test_data("47.fa.sig") sig63 = utils.get_test_data("63.fa.sig") @@ -1787,15 +1833,15 @@ def test_revindex_gather(): lidx.insert(ss47) lidx.insert(ss63) - matches = lidx.gather(ss2) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][1] == ss2 + match = lidx.best_containment(ss2) + assert match + assert match.score == 1.0 + assert match.signature == ss2 - matches = lidx.gather(ss47) - assert len(matches) == 1 - assert matches[0][0] == 1.0 - assert matches[0][1] == ss47 + match = lidx.best_containment(ss47) + assert match + assert match.score == 1.0 + assert match.signature == ss47 def test_revindex_gather_ignore(): diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 19f27788c8..e69a9b0c68 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -10,13 +10,16 @@ from sourmash import SourmashSignature from sourmash.index import (LinearIndex, ZipFileLinearIndex, LazyLinearIndex, MultiIndex, - StandaloneManifestIndex) + StandaloneManifestIndex, + IndexSearchResult) from sourmash.index import CounterGather from sourmash.index.sqlite_index import SqliteIndex from sourmash.index.revindex import RevIndex from sourmash.sbt import SBT, GraphFactory from sourmash.manifest import CollectionManifest, BaseCollectionManifest from sourmash.lca.lca_db import LCA_Database, load_single_database +from sourmash.minhash import (flatten_and_intersect_scaled, + flatten_and_downsample_scaled) import sourmash_tst_utils as utils @@ -365,23 +368,23 @@ def test_index_prefetch(index_obj): assert results[1].signature.minhash == ss63.minhash -def test_index_gather(index_obj): - # test basic gather +def test_index_best_containment(index_obj): + # test basic containment search ss2, ss47, ss63 = _load_three_sigs() - matches = index_obj.gather(ss2) - assert len(matches) == 1 - assert matches[0].score == 1.0 - assert matches[0].signature.minhash == ss2.minhash + match = 
index_obj.best_containment(ss2) + assert match + assert match.score == 1.0 + assert match.signature.minhash == ss2.minhash - matches = index_obj.gather(ss47) - assert len(matches) == 1 - assert matches[0].score == 1.0 - assert matches[0].signature.minhash == ss47.minhash + match = index_obj.best_containment(ss47) + assert match + assert match.score == 1.0 + assert match.signature.minhash == ss47.minhash -def test_index_gather_threshold_1(index_obj): - # test gather() method, in some detail +def test_index_best_containment_threshold_1(index_obj): + # test best_containment() method, in some detail ss2, ss47, ss63 = _load_three_sigs() # now construct query signatures with specific numbers of hashes -- @@ -393,21 +396,21 @@ def test_index_gather_threshold_1(index_obj): # query with empty hashes assert not new_mh with pytest.raises(ValueError): - index_obj.gather(SourmashSignature(new_mh)) + index_obj.best_containment(SourmashSignature(new_mh)) # add one hash new_mh.add_hash(mins.pop()) assert len(new_mh) == 1 - results = index_obj.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = index_obj.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == ss2.minhash # check with a threshold -> should be no results. 
with pytest.raises(ValueError): - index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) + index_obj.best_containment(SourmashSignature(new_mh), threshold_bp=5000) # add three more hashes => length of 4 new_mh.add_hash(mins.pop()) @@ -415,18 +418,18 @@ def test_index_gather_threshold_1(index_obj): new_mh.add_hash(mins.pop()) assert len(new_mh) == 4 - results = index_obj.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = index_obj.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == ss2.minhash # check with a too-high threshold -> should be no results. with pytest.raises(ValueError): - index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) + index_obj.best_containment(SourmashSignature(new_mh), threshold_bp=5000) -def test_gather_threshold_5(index_obj): +def test_best_containment_threshold_5(index_obj): # test gather() method, in some detail ss2, ss47, ss63 = _load_three_sigs() @@ -445,16 +448,16 @@ def test_gather_threshold_5(index_obj): new_mh.add_hash(mins.pop()) # should get a result with no threshold (any match at all is returned) - results = index_obj.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = index_obj.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == ss2.minhash # now, check with a threshold_bp that should be meet-able. 
- results = index_obj.gather(SourmashSignature(new_mh), threshold_bp=5000) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = index_obj.best_containment(SourmashSignature(new_mh), threshold_bp=5000) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == ss2.minhash @@ -474,8 +477,9 @@ class CounterGather_LinearIndex: Provides an (inefficient) CounterGather-style class, for protocol testing purposes. """ - def __init__(self, orig_query_mh): - "Constructor - take a FracMinHash that is the original query." + def __init__(self, orig_query): + "Constructor - take a SourmashSignature that is the original query." + orig_query_mh = orig_query.minhash if orig_query_mh.scaled == 0: raise ValueError @@ -510,6 +514,10 @@ def add(self, ss, *, location=None, require_overlap=True): self.idx.insert(ss) self.locations[md5] = location + def signatures(self): + "Yield all signatures" + return self.idx.signatures() + def downsample(self, scaled): "Track highest scaled across all possible matches." if scaled > self.scaled: @@ -521,11 +529,11 @@ def peek(self, cur_query_mh, *, threshold_bp=0): Find best match to current query within this CounterGather object. """ self.query_started = 1 - cur_query_mh = cur_query_mh.flatten() + scaled = self.downsample(cur_query_mh.scaled) - cur_query_mh = cur_query_mh.downsample(scaled=scaled) + cur_query_mh = flatten_and_downsample_scaled(cur_query_mh, scaled) - # no match? exit. + # no hashes remaining? exit. if not self.orig_query_mh or not cur_query_mh: return [] @@ -539,20 +547,110 @@ def peek(self, cur_query_mh, *, threshold_bp=0): return [] sr, intersect_mh = res - from sourmash.index import IndexSearchResult + # got match - replace location & return. 
match = sr.signature md5 = match.md5sum() location = self.locations[md5] - new_sr = IndexSearchResult(sr.score, match, location) - return new_sr, intersect_mh + return IndexSearchResult(sr.score, match, location), intersect_mh def consume(self, *args, **kwargs): self.query_started = 1 return self.idx.consume(*args, **kwargs) +class CounterGather_LCA: + """ + Provides an alternative implementation of a CounterGather-style class, + based on LCA_Database. This is currently just for protocol + and API testing purposes. + """ + def __init__(self, query): + from sourmash.lca.lca_db import LCA_Database + + query_mh = query.minhash + if query_mh.scaled == 0: + raise ValueError("must use scaled MinHash") + + self.orig_query_mh = query_mh + lca_db = LCA_Database(query_mh.ksize, query_mh.scaled, + query_mh.moltype) + self.db = lca_db + self.siglist = {} + self.locations = {} + self.query_started = 0 + + def add(self, ss, *, location=None, require_overlap=True): + "Add this signature into the counter." + if self.query_started: + raise ValueError("cannot add more signatures to counter after peek/consume") + + overlap = self.orig_query_mh.count_common(ss.minhash, True) + if overlap: + self.downsample(ss.minhash.scaled) + elif require_overlap: + raise ValueError("no overlap between query and signature!?") + + self.db.insert(ss) + + md5 = ss.md5sum() + self.siglist[md5] = ss + self.locations[md5] = location + + def signatures(self): + "Yield all signatures." + for ss in self.siglist.values(): + yield ss + + def downsample(self, scaled): + "Track highest scaled across all possible matches." + if scaled > self.db.scaled: + self.db.downsample_scaled(scaled) + return self.db.scaled + + def peek(self, query_mh, *, threshold_bp=0): + "Return next possible match." 
+ from sourmash import SourmashSignature + + self.query_started = 1 + scaled = self.downsample(query_mh.scaled) + query_mh = query_mh.downsample(scaled=scaled) + + if not self.orig_query_mh or not query_mh: + return [] + + if query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: + raise ValueError("current query not a subset of original query") + + query_ss = SourmashSignature(query_mh) + + # returns search_result, intersect_mh + try: + result = self.db.best_containment(query_ss, threshold_bp=threshold_bp) + except ValueError: + result = None + + if not result: + return [] + + cont = result.score + match = result.signature + + intersect_mh = flatten_and_intersect_scaled(result.signature.minhash, + query_mh) + + md5 = result.signature.md5sum() + location = self.locations[md5] + + new_sr = IndexSearchResult(cont, match, location) + return [new_sr, intersect_mh] + + def consume(self, intersect_mh): + self.query_started = 1 + + @pytest.fixture(params=[CounterGather, CounterGather_LinearIndex, + CounterGather_LCA, ] ) def counter_gather_constructor(request): @@ -562,6 +660,36 @@ def counter_gather_constructor(request): return build_fn +def test_counter_get_signatures(counter_gather_constructor): + # test .signatures() method + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(10, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear() + match_mh_3.add_many(range(15, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + counter = counter_gather_constructor(query_ss) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + siglist = list(counter.signatures()) + assert 
len(siglist) == 3 + assert match_ss_1 in siglist + assert match_ss_2 in siglist + assert match_ss_3 in siglist + + def _consume_all(query_mh, counter, threshold_bp=0): results = [] query_mh = query_mh.to_mutable() @@ -607,7 +735,7 @@ def test_counter_gather_1(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -649,7 +777,7 @@ def test_counter_gather_1_b(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -693,7 +821,7 @@ def test_counter_gather_1_c_with_threshold(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -731,7 +859,7 @@ def test_counter_gather_1_d_diff_scaled(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -771,7 +899,7 @@ def test_counter_gather_1_d_diff_scaled_query(counter_gather_constructor): query_ss = SourmashSignature(query_mh.downsample(scaled=100), name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -809,7 +937,7 @@ def 
test_counter_gather_1_e_abund_query(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -848,7 +976,7 @@ def test_counter_gather_1_f_abund_match(counter_gather_constructor): match_ss_3 = SourmashSignature(match_mh_3, name='match3') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) counter.add(match_ss_2) counter.add(match_ss_3) @@ -880,7 +1008,7 @@ def test_counter_gather_2(counter_gather_constructor): for t in testdata_sigs ] # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) for ss, loc in subject_sigs: counter.add(ss, location=loc) @@ -915,7 +1043,7 @@ def test_counter_gather_exact_match(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter; provide a location override, too. - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') results = _consume_all(query_ss.minhash, counter) @@ -934,7 +1062,7 @@ def test_counter_gather_multiple_identical_matches(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # create counter... - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) # now add multiple identical matches. 
match_mh = query_mh.copy_and_clear() @@ -962,7 +1090,7 @@ def test_counter_gather_add_after_peek(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') counter.peek(query_ss.minhash) @@ -978,7 +1106,7 @@ def test_counter_gather_add_after_consume(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') counter.consume(query_ss.minhash) @@ -994,7 +1122,7 @@ def test_counter_gather_consume_empty_intersect(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') # nothing really happens here :laugh:, just making sure there's no error @@ -1011,7 +1139,7 @@ def test_counter_gather_empty_initial_query(counter_gather_constructor): match_ss_1 = SourmashSignature(match_mh_1, name='match1') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1, require_overlap=False) assert counter.peek(query_ss.minhash) == [] @@ -1024,7 +1152,7 @@ def test_counter_gather_num_query(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') with pytest.raises(ValueError): - counter_gather_constructor(query_ss.minhash) + counter_gather_constructor(query_ss) def test_counter_gather_empty_cur_query(counter_gather_constructor): @@ -1034,7 +1162,7 @@ def test_counter_gather_empty_cur_query(counter_gather_constructor): query_ss = 
SourmashSignature(query_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') cur_query_mh = query_ss.minhash.copy_and_clear() @@ -1053,7 +1181,7 @@ def test_counter_gather_add_num_matchy(counter_gather_constructor): match_ss = SourmashSignature(match_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) with pytest.raises(ValueError): counter.add(match_ss, location='somewhere over the rainbow') @@ -1065,7 +1193,7 @@ def test_counter_gather_bad_cur_query(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(query_ss, location='somewhere over the rainbow') cur_query_mh = query_ss.minhash.copy_and_clear() @@ -1085,7 +1213,7 @@ def test_counter_gather_add_no_overlap(counter_gather_constructor): match_ss_1 = SourmashSignature(match_mh_1, name='match1') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) with pytest.raises(ValueError): counter.add(match_ss_1) @@ -1103,7 +1231,7 @@ def test_counter_gather_big_threshold(counter_gather_constructor): match_ss_1 = SourmashSignature(match_mh_1, name='match1') # load up the counter - counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) counter.add(match_ss_1) # impossible threshold: @@ -1118,6 +1246,6 @@ def test_counter_gather_empty_counter(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # empty counter! 
- counter = counter_gather_constructor(query_ss.minhash) + counter = counter_gather_constructor(query_ss) assert counter.peek(query_ss.minhash) == [] diff --git a/tests/test_lca.py b/tests/test_lca.py index af9778c76b..6602d46d86 100644 --- a/tests/test_lca.py +++ b/tests/test_lca.py @@ -335,10 +335,10 @@ def test_api_create_gather(): lca_db = sourmash.lca.LCA_Database(ksize=31, scaled=1000) lca_db.insert(ss) - results = lca_db.gather(ss, threshold_bp=0) - print(results) - assert len(results) == 1 - (similarity, match, filename) = results[0] + result = lca_db.best_containment(ss, threshold_bp=0) + print(result) + assert result + (similarity, match, filename) = result assert match.minhash == ss.minhash @@ -682,8 +682,8 @@ def test_search_db_scaled_lt_sig_scaled(): results = db.search(sig, threshold=.01, ignore_abundance=True) print(results) - assert results[0][0] == 1.0 - match = results[0][1] + assert results[0].score == 1.0 + match = results[0].signature orig_sig = sourmash.load_one_signature(utils.get_test_data('47.fa.sig')) assert orig_sig.minhash.jaccard(match.minhash, downsample=True) == 1.0 @@ -694,8 +694,8 @@ def test_gather_db_scaled_gt_sig_scaled(): db, ksize, scaled = lca_utils.load_single_database(dbfile) sig = sourmash.load_one_signature(utils.get_test_data('47.fa.sig')) - results = db.gather(sig, threshold=.01, ignore_abundance=True) - match_sig = results[0][1] + result = db.best_containment(sig, threshold=.01, ignore_abundance=True) + match_sig = result[1] sig.minhash = sig.minhash.downsample(scaled=10000) assert sig.minhash == match_sig.minhash @@ -707,8 +707,8 @@ def test_gather_db_scaled_lt_sig_scaled(): sig = sourmash.load_one_signature(utils.get_test_data('47.fa.sig')) sig.minhash = sig.minhash.downsample(scaled=100000) - results = db.gather(sig, threshold=.01, ignore_abundance=True) - match_sig = results[0][1] + result = db.best_containment(sig, threshold=.01, ignore_abundance=True) + match_sig = result[1] match_sig.minhash = 
match_sig.minhash.downsample(scaled=100000) assert sig.minhash == match_sig.minhash @@ -2386,9 +2386,9 @@ def test_lca_index_empty(runtmp, lca_db_format): lca_db_filename = c.output(f'xxx.lca.{lca_db_format}') db, ksize, scaled = lca_utils.load_single_database(lca_db_filename) - results = db.gather(sig63) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = db.best_containment(sig63) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == sig63.minhash assert name == lca_db_filename @@ -2419,22 +2419,22 @@ def test_lca_gather_threshold_1(): # query with empty hashes assert not new_mh with pytest.raises(ValueError): - db.gather(SourmashSignature(new_mh)) + db.best_containment(SourmashSignature(new_mh)) # add one hash new_mh.add_hash(mins.pop()) assert len(new_mh) == 1 - results = db.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = db.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == sig2.minhash assert name == None # check with a threshold -> should be no results. with pytest.raises(ValueError): - db.gather(SourmashSignature(new_mh), threshold_bp=5000) + db.best_containment(SourmashSignature(new_mh), threshold_bp=5000) # add three more hashes => length of 4 new_mh.add_hash(mins.pop()) @@ -2442,16 +2442,16 @@ def test_lca_gather_threshold_1(): new_mh.add_hash(mins.pop()) assert len(new_mh) == 4 - results = db.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = db.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == sig2.minhash assert name == None # check with a too-high threshold -> should be no results. 
with pytest.raises(ValueError): - db.gather(SourmashSignature(new_mh), threshold_bp=5000) + db.best_containment(SourmashSignature(new_mh), threshold_bp=5000) def test_lca_gather_threshold_5(): @@ -2485,17 +2485,17 @@ def test_lca_gather_threshold_5(): new_mh.add_hash(mins.pop()) # should get a result with no threshold (any match at all is returned) - results = db.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = db.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == sig2.minhash assert name == None # now, check with a threshold_bp that should be meet-able. - results = db.gather(SourmashSignature(new_mh), threshold_bp=5000) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = db.best_containment(SourmashSignature(new_mh), threshold_bp=5000) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig.minhash == sig2.minhash assert name == None @@ -2518,10 +2518,10 @@ def test_gather_multiple_return(): # now, run gather. how many results do we get, and are they in the # right order? 
- results = db.gather(sig63) - print(len(results)) - assert len(results) == 1 - assert results[0][0] == 1.0 + result = db.best_containment(sig63) + print(result) + assert result + assert result.score == 1.0 def test_lca_db_protein_build(): @@ -2546,8 +2546,8 @@ def test_lca_db_protein_build(): results = db.search(sig1, threshold=0.0) assert len(results) == 2 - results = db.gather(sig2) - assert results[0][0] == 1.0 + result = db.best_containment(sig2) + assert result.score == 1.0 @utils.in_tempdir @@ -2582,8 +2582,8 @@ def test_lca_db_protein_save_load(c): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 def test_lca_db_protein_command_index(runtmp, lca_db_format): @@ -2618,8 +2618,8 @@ def test_lca_db_protein_command_index(runtmp, lca_db_format): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 @utils.in_thisdir @@ -2659,8 +2659,8 @@ def test_lca_db_hp_build(): results = db.search(sig1, threshold=0.0) assert len(results) == 2 - results = db.gather(sig2) - assert results[0][0] == 1.0 + result = db.best_containment(sig2) + assert result.score == 1.0 @utils.in_tempdir @@ -2693,8 +2693,8 @@ def test_lca_db_hp_save_load(c): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 def test_lca_db_hp_command_index(runtmp, lca_db_format): @@ -2729,8 +2729,8 @@ def test_lca_db_hp_command_index(runtmp, lca_db_format): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 @utils.in_thisdir @@ -2770,8 +2770,8 @@ def 
test_lca_db_dayhoff_build(): results = db.search(sig1, threshold=0.0) assert len(results) == 2 - results = db.gather(sig2) - assert results[0][0] == 1.0 + result = db.best_containment(sig2) + assert result.score == 1.0 @utils.in_tempdir @@ -2804,8 +2804,8 @@ def test_lca_db_dayhoff_save_load(c): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 def test_lca_db_dayhoff_command_index(runtmp, lca_db_format): @@ -2840,8 +2840,8 @@ def test_lca_db_dayhoff_command_index(runtmp, lca_db_format): results = db2.search(sig1, threshold=0.0) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 + result = db2.best_containment(sig2) + assert result.score == 1.0 @utils.in_thisdir diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 9d9ba7273a..3c83915e9c 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -830,22 +830,22 @@ def test_sbt_gather_threshold_1(): # query with empty hashes assert not new_mh with pytest.raises(ValueError): - tree.gather(SourmashSignature(new_mh)) + tree.best_containment(SourmashSignature(new_mh)) # add one hash new_mh.add_hash(mins.pop()) assert len(new_mh) == 1 - results = tree.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = tree.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None # check with a threshold -> should be no results. 
with pytest.raises(ValueError): - tree.gather(SourmashSignature(new_mh), threshold_bp=5000) + tree.best_containment(SourmashSignature(new_mh), threshold_bp=5000) # add three more hashes => length of 4 new_mh.add_hash(mins.pop()) @@ -853,9 +853,9 @@ def test_sbt_gather_threshold_1(): new_mh.add_hash(mins.pop()) assert len(new_mh) == 4 - results = tree.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = tree.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None @@ -863,7 +863,7 @@ def test_sbt_gather_threshold_1(): # check with a too-high threshold -> should be no results. print('len mh', len(new_mh)) with pytest.raises(ValueError): - tree.gather(SourmashSignature(new_mh), threshold_bp=5000) + tree.best_containment(SourmashSignature(new_mh), threshold_bp=5000) def test_sbt_gather_threshold_5(): @@ -894,17 +894,17 @@ def test_sbt_gather_threshold_5(): new_mh.add_hash(mins.pop()) # should get a result with no threshold (any match at all is returned) - results = tree.gather(SourmashSignature(new_mh)) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = tree.best_containment(SourmashSignature(new_mh)) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None # now, check with a threshold_bp that should be meet-able. - results = tree.gather(SourmashSignature(new_mh), threshold_bp=5000) - assert len(results) == 1 - containment, match_sig, name = results[0] + result = tree.best_containment(SourmashSignature(new_mh), threshold_bp=5000) + assert result + containment, match_sig, name = result assert containment == 1.0 assert match_sig == sig2 assert name is None @@ -931,10 +931,10 @@ def test_gather_single_return(c): # now, run gather. how many results do we get, and are they in the # right order? 
- results = tree.gather(sig63) - print(len(results)) - assert len(results) == 1 - assert results[0][0] == 1.0 + result = tree.best_containment(sig63) + print(result) + assert result + assert result.score == 1.0 def test_sbt_jaccard_ordering(runtmp): @@ -1015,10 +1015,10 @@ def test_sbt_protein_command_index(runtmp): do_containment=False, best_only=False) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 - assert results[0][2] == db2._location - assert results[0][2] == db_out + result = db2.best_containment(sig2) + assert result.score == 1.0 + assert result.location == db2._location + assert result.location == db_out @utils.in_tempdir @@ -1081,12 +1081,12 @@ def test_sbt_hp_command_index(c): # and search, gather results = db2.search(sig1, threshold=0.0, ignore_abundance=True, do_containment=False, best_only=False) - assert len(results) == 2 + assert results - results = db2.gather(sig2) - assert results[0][0] == 1.0 - assert results[0][2] == db2._location - assert results[0][2] == db_out + result = db2.best_containment(sig2) + assert result.score == 1.0 + assert result.location == db2._location + assert result.location == db_out @utils.in_thisdir @@ -1130,10 +1130,10 @@ def test_sbt_dayhoff_command_index(c): do_containment=False, best_only=False) assert len(results) == 2 - results = db2.gather(sig2) - assert results[0][0] == 1.0 - assert results[0][2] == db2._location - assert results[0][2] == db_out + result = db2.best_containment(sig2) + assert result.score == 1.0 + assert result.location == db2._location + assert result.location == db_out @utils.in_thisdir diff --git a/tests/test_search.py b/tests/test_search.py index 0d765e0f96..55fbf4a73a 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,13 +1,12 @@ "Tests for search.py code." -# CTB TODO: test search protocol with mock class? 
- import pytest import numpy as np import sourmash_tst_utils as utils from sourmash import search, SourmashSignature, MinHash, load_one_signature -from sourmash.search import (make_jaccard_search_query, make_gather_query, +from sourmash.search import (make_jaccard_search_query, + make_containment_query, SearchResult, PrefetchResult, GatherResult) from sourmash.index import LinearIndex @@ -129,35 +128,35 @@ def test_collect_best_only(): assert search_obj.threshold == 1.0 -def test_make_gather_query(): - # test basic make_gather_query call +def test_make_containment_query(): + # test basic make_containment_query call mh = MinHash(n=0, ksize=31, scaled=1000) for i in range(100): mh.add_hash(i) - search_obj = make_gather_query(mh, 5e4) + search_obj = make_containment_query(mh, 5e4) assert search_obj.score_fn == search_obj.score_containment assert search_obj.require_scaled assert search_obj.threshold == 0.5 -def test_make_gather_query_no_threshold(): - # test basic make_gather_query call +def test_make_containment_query_no_threshold(): + # test basic make_containment_query call mh = MinHash(n=0, ksize=31, scaled=1000) for i in range(100): mh.add_hash(i) - search_obj = make_gather_query(mh, None) + search_obj = make_containment_query(mh, None) assert search_obj.score_fn == search_obj.score_containment assert search_obj.require_scaled assert search_obj.threshold == 0 -def test_make_gather_query_num_minhash(): +def test_make_containment_query_num_minhash(): # will fail on non-scaled minhash mh = MinHash(n=500, ksize=31) @@ -165,12 +164,12 @@ def test_make_gather_query_num_minhash(): mh.add_hash(i) with pytest.raises(TypeError) as exc: - search_obj = make_gather_query(mh, 5e4) + search_obj = make_containment_query(mh, 5e4) assert str(exc.value) == "query signature must be calculated with scaled" -def test_make_gather_query_empty_minhash(): +def test_make_containment_query_empty_minhash(): # will fail on non-scaled minhash mh = MinHash(n=0, ksize=31, scaled=1000) @@ -178,12 
+177,12 @@ def test_make_gather_query_empty_minhash(): mh.add_hash(i) with pytest.raises(TypeError) as exc: - search_obj = make_gather_query(mh, -1) + search_obj = make_containment_query(mh, -1) assert str(exc.value) == "threshold_bp must be non-negative" -def test_make_gather_query_high_threshold(): +def test_make_containment_query_high_threshold(): # will fail on non-scaled minhash mh = MinHash(n=0, ksize=31, scaled=1000) @@ -192,7 +191,7 @@ def test_make_gather_query_high_threshold(): # effective threshold > 1; raise ValueError with pytest.raises(ValueError): - search_obj = make_gather_query(mh, 200000) + search_obj = make_containment_query(mh, 200000) class FakeIndex(LinearIndex): @@ -223,8 +222,8 @@ def validate_kwarg_passthru(search_fn, query, args, kwargs): idx.search(query, threshold=0.0, this_kw_arg=5) -def test_index_gather_passthru(): - # check that kwargs are passed through from 'gather' to 'find' +def test_index_containment_passthru(): + # check that kwargs are passed through from 'search' to 'find' query = None def validate_kwarg_passthru(search_fn, query, args, kwargs):