Skip to content

Commit

Permalink
Rework the find functionality for Index classes (#1392)
Browse files Browse the repository at this point in the history
* have the 'find' function for SBTs return signatures
* fix majority of tests
* split find and _find_nodes to take different kinds of functions
* redo 'find' on index
* refactor lca_db to use new find
* refactor SBT to use new find
* refactor out common code
* use 'passes' properly
* adjust tree downsampling for regular minhashes, too
* remove now-unused search functions in sbtmh
* refactor categorize to use new find
* fix jaccard calculation in sbt
* check for compatibility of search fn and query signature
* switch tests over to jaccard similarity, not containment
* remove test for unimplemented LCA_Database.find method
* document threshold change; update test
* refuse to run abund signatures
* flatten sigs internally for gather
* reinflate abundances for saving
* fix problem where sbt indices coudl be created with abund signatures
* split flat and abund search
* make ignore_abundance work again for categorize
* turn off best-only, since it triggers on self-hits.
* add test: 'sourmash index' flattens sigs
* location is now a property
* move search code into search.py
* remove redundant scaled checking code
* best-only now works properly for two tests
* 'fix' tests by removing v1 and v2 SBT compatibility
* simplify downsampling code
* require keyword args in MinHash.downsample(...)
* fix test to use proper downsampling, reverse order to match scaled
* flatten subject MinHash, too
* add IndexSearchResult namedtuple for search and gather results
* add more tests for Index classes
* add tests for subj & query num downsampling
* tests for Index.search_abund
* refactor make_jaccard_search_query; start tests
* test collect, best_only
* deal with status == None on SystemExit
* upgrade and simplify categorize
* fix abundance search in SBT for categorize
* add explicit test for incompatible num
* add simple tests for SBT load and search API
* allow arbitrary kwargs for LCA_DAtabase.find
* add testing of passthru-kwargs
* docstring updates
* better tests for gather --save-unassigned
* SBT search doesn't work on v1 and v2 SBTs b/c no min_n_below
* add intersection_and_union_size method to MinHash
* make flatten a no-op if track_abundance=False
* intersection_union_size in the FFI

Co-authored-by: Luiz Irber <luiz.irber@gmail.com>
  • Loading branch information
ctb and luizirber authored Apr 22, 2021
1 parent eb2b210 commit f02e250
Show file tree
Hide file tree
Showing 18 changed files with 1,419 additions and 659 deletions.
7 changes: 6 additions & 1 deletion include/sourmash.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,12 @@ void kmerminhash_hash_function_set(SourmashKmerMinHash *ptr, HashFunctions hash_

bool kmerminhash_hp(const SourmashKmerMinHash *ptr);

uint64_t kmerminhash_intersection(const SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other);
SourmashKmerMinHash *kmerminhash_intersection(const SourmashKmerMinHash *ptr,
const SourmashKmerMinHash *other);

uint64_t kmerminhash_intersection_union_size(const SourmashKmerMinHash *ptr,
const SourmashKmerMinHash *other,
uint64_t *union_size);

bool kmerminhash_is_compatible(const SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other);

Expand Down
21 changes: 19 additions & 2 deletions src/core/src/ffi/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,31 @@ unsafe fn kmerminhash_count_common(ptr: *const SourmashKmerMinHash, other: *cons

ffi_fn! {
unsafe fn kmerminhash_intersection(ptr: *const SourmashKmerMinHash, other: *const SourmashKmerMinHash)
-> Result<*mut SourmashKmerMinHash> {
let mh = SourmashKmerMinHash::as_rust(ptr);
let other_mh = SourmashKmerMinHash::as_rust(other);

let isect = mh.intersection(other_mh)?;
let mut new_mh = mh.clone();
new_mh.clear();
new_mh.add_many(&isect.0)?;

Ok(SourmashKmerMinHash::from_rust(new_mh))
}
}

ffi_fn! {
unsafe fn kmerminhash_intersection_union_size(ptr: *const SourmashKmerMinHash, other: *const SourmashKmerMinHash, union_size: *mut u64)
-> Result<u64> {
let mh = SourmashKmerMinHash::as_rust(ptr);
let other_mh = SourmashKmerMinHash::as_rust(other);

if let Ok((_, size)) = mh.intersection_size(other_mh) {
return Ok(size);
if let Ok((common, union_s)) = mh.intersection_size(other_mh) {
*union_size = union_s;
return Ok(common);
}

*union_size = 0;
Ok(0)
}
}
Expand Down
1 change: 1 addition & 0 deletions src/core/src/sketch/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ impl KmerMinHash {
}

// FIXME: intersection_size and count_common should be the same?
// (for scaled minhashes)
pub fn intersection_size(&self, other: &KmerMinHash) -> Result<(u64, u64), Error> {
self.check_compatible(other)?;

Expand Down
4 changes: 2 additions & 2 deletions src/sourmash/cli/categorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

def subparser(subparsers):
subparser = subparsers.add_parser('categorize')
subparser.add_argument('sbt_name', help='name of SBT to load')
subparser.add_argument('database', help='location of signature collection/database to load')
subparser.add_argument(
'queries', nargs='+',
help='list of signatures to categorize'
help='locations of signatures to categorize'
)
subparser.add_argument(
'-q', '--quiet', action='store_true',
Expand Down
127 changes: 82 additions & 45 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import os.path
import sys
import copy

import screed
from .compare import (compare_all_pairs, compare_serial_containment,
Expand All @@ -14,11 +15,8 @@
from . import signature as sig
from . import sourmash_args
from .logging import notify, error, print_results, set_quiet
from .sbtmh import SearchMinHashesFindBest, SigLeaf

from .sourmash_args import DEFAULT_LOAD_K, FileOutput, FileOutputCSV

DEFAULT_N = 500
WATERMARK_SIZE = 10000

from .command_compute import compute
Expand Down Expand Up @@ -385,6 +383,8 @@ def index(args):

if args.scaled:
ss.minhash = ss.minhash.downsample(scaled=args.scaled)
if ss.minhash.track_abundance:
ss.minhash = ss.minhash.flatten()
scaleds.add(ss.minhash.scaled)

tree.insert(ss)
Expand Down Expand Up @@ -422,7 +422,8 @@ def index(args):


def search(args):
from .search import search_databases
from .search import (search_databases_with_flat_query,
search_databases_with_abund_query)

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)
Expand Down Expand Up @@ -457,22 +458,32 @@ def search(args):
databases = sourmash_args.load_dbs_and_sigs(args.databases, query,
not is_containment)

# forcibly ignore abundances if query has no abundances
if not query.minhash.track_abundance:
args.ignore_abundance = True

if not len(databases):
error('Nothing found to search!')
sys.exit(-1)

# forcibly ignore abundances if query has no abundances
if not query.minhash.track_abundance:
args.ignore_abundance = True
else:
if args.ignore_abundance:
query.minhash = query.minhash.flatten()

# do the actual search
results = search_databases(query, databases,
threshold=args.threshold,
do_containment=args.containment,
do_max_containment=args.max_containment,
best_only=args.best_only,
ignore_abundance=args.ignore_abundance,
unload_data=True)
if query.minhash.track_abundance:
results = search_databases_with_abund_query(query, databases,
threshold=args.threshold,
do_containment=args.containment,
do_max_containment=args.max_containment,
best_only=args.best_only,
unload_data=True)
else:
results = search_databases_with_flat_query(query, databases,
threshold=args.threshold,
do_containment=args.containment,
do_max_containment=args.max_containment,
best_only=args.best_only,
unload_data=True)

n_matches = len(results)
if args.best_only:
Expand Down Expand Up @@ -520,6 +531,7 @@ def search(args):
def categorize(args):
"Use a database to find the best match to many signatures."
from .index import MultiIndex
from .search import make_jaccard_search_query

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)
Expand All @@ -533,7 +545,9 @@ def categorize(args):
already_names.add(row[0])

# load search database
tree = load_sbt_index(args.sbt_name)
db = sourmash_args.load_file_as_index(args.database)
if args.ksize or moltype:
db = db.select(ksize=args.ksize, moltype=moltype)

# utility function to load & select relevant signatures.
def _yield_all_sigs(queries, ksize, moltype):
Expand All @@ -549,40 +563,44 @@ def _yield_all_sigs(queries, ksize, moltype):
csv_fp = open(args.csv, 'w', newline='')
csv_w = csv.writer(csv_fp)

for query, loc in _yield_all_sigs(args.queries, args.ksize, moltype):
search_obj = make_jaccard_search_query(threshold=args.threshold)
for orig_query, loc in _yield_all_sigs(args.queries, args.ksize, moltype):
# skip if we've already done signatures from this file.
if loc in already_names:
continue

notify('loaded query: {}... (k={}, {})', str(query)[:30],
query.minhash.ksize, query.minhash.moltype)
notify('loaded query: {}... (k={}, {})', str(orig_query)[:30],
orig_query.minhash.ksize, orig_query.minhash.moltype)

results = []
search_fn = SearchMinHashesFindBest().search
if args.ignore_abundance:
query = copy.copy(orig_query)
query.minhash = query.minhash.flatten()
else:
if orig_query.minhash.track_abundance:
notify("ERROR: this search cannot be done on signatures calculated with abundance.")
notify("ERROR: please specify --ignore-abundance.")
sys.exit(-1)

# note, "ignore self" here may prevent using newer 'tree.search' fn.
for leaf in tree.find(search_fn, query, args.threshold):
if leaf.data.md5sum() != query.md5sum(): # ignore self.
similarity = query.similarity(
leaf.data, ignore_abundance=args.ignore_abundance)
results.append((similarity, leaf.data))
query = orig_query

results = []
for match, score in db.find(search_obj, query):
if match.md5sum() != query.md5sum(): # ignore self.
results.append((orig_query.similarity(match), match))

best_hit_sim = 0.0
best_hit_query_name = ""
if results:
results.sort(key=lambda x: -x[0]) # reverse sort on similarity
best_hit_sim, best_hit_query = results[0]
notify('for {}, found: {:.2f} {}', query,
best_hit_sim,
best_hit_query)
best_hit_query_name = best_hit_query.name
if csv_w:
csv_w.writerow([loc, query, best_hit_query_name,
best_hit_sim])
else:
notify('for {}, no match found', query)

if csv_w:
csv_w.writerow([loc, query, best_hit_query_name,
best_hit_sim])

if csv_fp:
csv_fp.close()

Expand Down Expand Up @@ -631,12 +649,14 @@ def gather(args):

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
new_max_hash = query.minhash._max_hash
next_query = query

for result, weighted_missed, new_max_hash, next_query in gather_databases(query, databases, args.threshold_bp, args.ignore_abundance):
if not len(found): # first result? print header.
if query.minhash.track_abundance and not args.ignore_abundance:
if is_abundance:
print_results("")
print_results("overlap p_query p_match avg_abund")
print_results("--------- ------- ------- ---------")
Expand All @@ -651,7 +671,7 @@ def gather(args):
pct_genome = '{:.1f}%'.format(result.f_match*100)
name = result.match._display_name(40)

if query.minhash.track_abundance and not args.ignore_abundance:
if is_abundance:
average_abund ='{:.1f}'.format(result.average_abund)
print_results('{:9} {:>7} {:>7} {:>9} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
Expand All @@ -666,16 +686,16 @@ def gather(args):
break


# basic reporting
print_results('\nfound {} matches total;', len(found))
# basic reporting:
print_results(f'\nfound {len(found)} matches total;')
if args.num_results and len(found) == args.num_results:
print_results('(truncated gather because --num-results={})',
args.num_results)
print_results(f'(truncated gather because --num-results={args.num_results})')

print_results('the recovered matches hit {:.1f}% of the query',
(1 - weighted_missed) * 100)
p_covered = (1 - weighted_missed) * 100
print_results(f'the recovered matches hit {p_covered:.1f}% of the query')
print_results('')

# save CSV?
if found and args.output:
fieldnames = ['intersect_bp', 'f_orig_query', 'f_match',
'f_unique_to_query', 'f_unique_weighted',
Expand All @@ -691,19 +711,34 @@ def gather(args):
del d['match'] # actual signature not in CSV.
w.writerow(d)

# save matching signatures?
if found and args.save_matches:
notify('saving all matches to "{}"', args.save_matches)
notify(f"saving all matches to '{args.save_matches}'")
with FileOutput(args.save_matches, 'wt') as fp:
sig.save_signatures([ r.match for r in found ], fp)

# save unassigned hashes?
if args.output_unassigned:
if not len(next_query.minhash):
notify('no unassigned hashes to save with --output-unassigned!')
else:
notify('saving unassigned hashes to "{}"', args.output_unassigned)
notify(f"saving unassigned hashes to '{args.output_unassigned}'")

if is_abundance:
# next_query is flattened; reinflate abundances
hashes = set(next_query.minhash.hashes)
orig_abunds = orig_query_mh.hashes
abunds = { h: orig_abunds[h] for h in hashes }

abund_query_mh = orig_query_mh.copy_and_clear()
# orig_query might have been downsampled...
abund_query_mh.downsample(scaled=next_query.minhash.scaled)
abund_query_mh.set_abundances(abunds)
next_query.minhash = abund_query_mh

with FileOutput(args.output_unassigned, 'wt') as fp:
sig.save_signatures([ next_query ], fp)
# DONE w/gather function.


def multigather(args):
Expand Down Expand Up @@ -765,9 +800,10 @@ def multigather(args):

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
for result, weighted_missed, new_max_hash, next_query in gather_databases(query, databases, args.threshold_bp, args.ignore_abundance):
if not len(found): # first result? print header.
if query.minhash.track_abundance and not args.ignore_abundance:
if is_abundance:
print_results("")
print_results("overlap p_query p_match avg_abund")
print_results("--------- ------- ------- ---------")
Expand All @@ -782,7 +818,7 @@ def multigather(args):
pct_genome = '{:.1f}%'.format(result.f_match*100)
name = result.match._display_name(40)

if query.minhash.track_abundance and not args.ignore_abundance:
if is_abundance:
average_abund ='{:.1f}'.format(result.average_abund)
print_results('{:9} {:>7} {:>7} {:>9} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
Expand Down Expand Up @@ -844,6 +880,7 @@ def multigather(args):

e = MinHash(ksize=query.minhash.ksize, n=0, max_hash=new_max_hash)
e.add_many(next_query.minhash.hashes)
# CTB: note, multigather does not save abundances
sig.save_signatures([ sig.SourmashSignature(e) ], fp)
n += 1

Expand Down
Loading

0 comments on commit f02e250

Please sign in to comment.