Skip to content

Commit

Permalink
[MRG] Doing bounds checking on --scaled via command line (#1650)
Browse files Browse the repository at this point in the history
* Index function

* Search function

* Gather function

* Multigather function

* Prefetch function

* For loop Index function

* Modified all functions

* Search function tests

* Gather function tests

* Multigather function tests

* Index function tests

* Prefetch function tests

* Made final changes

* Added function in command line files

* Made changes in function and tests

* Imported argparse

* Removed code block from gather function

* Removed code blocks and their corresponding tests

* Prefetch function cli

* Search function cli

* Multigather function cli

* Index function cli

* Made final changes

* Made requested changes

* Fixed tests and cli function code

* Fixed indentations

Co-authored-by: Tessa Pierce Ward <bluegenes@users.noreply.github.com>
  • Loading branch information
keyabarve and bluegenes authored Jul 16, 2021
1 parent 67aba79 commit 96920c7
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 34 deletions.
7 changes: 2 additions & 5 deletions src/sourmash/cli/gather.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""search a metagenome signature against dbs"""

from sourmash.cli.utils import (add_ksize_arg, add_moltype_args,
add_picklist_args)
add_picklist_args, add_scaled_arg)


def subparser(subparsers):
Expand Down Expand Up @@ -45,10 +45,6 @@ def subparser(subparsers):
help='output unassigned portions of the query as a signature to the '
'specified file'
)
subparser.add_argument(
'--scaled', metavar='FLOAT', type=float, default=0,
help='downsample query to the specified scaled factor'
)
subparser.add_argument(
'--ignore-abundance', action='store_true',
help='do NOT use k-mer abundances if present'
Expand Down Expand Up @@ -82,6 +78,7 @@ def subparser(subparsers):
add_ksize_arg(subparser, 31)
add_moltype_args(subparser)
add_picklist_args(subparser)
add_scaled_arg(subparser, 0)


def main(args):
Expand Down
7 changes: 2 additions & 5 deletions src/sourmash/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from sourmash.cli.utils import (add_ksize_arg, add_moltype_args,
add_picklist_args)
add_picklist_args, add_scaled_arg)


def subparser(subparsers):
Expand Down Expand Up @@ -66,13 +66,10 @@ def subparser(subparsers):
help='What percentage of internal nodes will not be saved; ranges '
'from 0.0 (save all nodes) to 1.0 (no nodes saved)'
)
subparser.add_argument(
'--scaled', metavar='FLOAT', type=float, default=0,
help='downsample signatures to the specified scaled factor'
)
add_ksize_arg(subparser, 31)
add_moltype_args(subparser)
add_picklist_args(subparser)
add_scaled_arg(subparser, 0)


def main(args):
Expand Down
7 changes: 2 additions & 5 deletions src/sourmash/cli/multigather.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"'sourmash multigather' - gather many signatures against multiple databases."

from sourmash.cli.utils import add_ksize_arg, add_moltype_args
from sourmash.cli.utils import add_ksize_arg, add_moltype_args, add_scaled_arg


def subparser(subparsers):
Expand Down Expand Up @@ -28,16 +28,13 @@ def subparser(subparsers):
'--threshold-bp', metavar='REAL', type=float, default=5e4,
help='threshold (in bp) for reporting results (default=50,000)'
)
subparser.add_argument(
'--scaled', metavar='FLOAT', type=float, default=0,
help='downsample query to the specified scaled factor'
)
subparser.add_argument(
'--ignore-abundance', action='store_true',
help='do NOT use k-mer abundances if present'
)
add_ksize_arg(subparser, 31)
add_moltype_args(subparser)
add_scaled_arg(subparser, 0)


def main(args):
Expand Down
7 changes: 2 additions & 5 deletions src/sourmash/cli/prefetch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""search a signature against dbs, find all overlaps"""

from sourmash.cli.utils import (add_ksize_arg, add_moltype_args,
add_picklist_args)
add_picklist_args, add_scaled_arg)


def subparser(subparsers):
Expand Down Expand Up @@ -54,17 +54,14 @@ def subparser(subparsers):
help='output matching query hashes as a signature to the '
'specified file'
)
subparser.add_argument(
'--scaled', metavar='FLOAT', type=float, default=None,
help='downsample signatures to the specified scaled factor'
)
subparser.add_argument(
'--md5', default=None,
help='select the signature with this md5 as query'
)
add_ksize_arg(subparser, 31)
add_moltype_args(subparser)
add_picklist_args(subparser)
add_scaled_arg(subparser, 0)


def main(args):
Expand Down
7 changes: 2 additions & 5 deletions src/sourmash/cli/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""search a signature against other signatures"""

from sourmash.cli.utils import (add_ksize_arg, add_moltype_args,
add_picklist_args)
add_picklist_args, add_scaled_arg)


def subparser(subparsers):
Expand Down Expand Up @@ -46,10 +46,6 @@ def subparser(subparsers):
help='do NOT use k-mer abundances if present; note: has no effect if '
'--containment or --max-containment is specified'
)
subparser.add_argument(
'--scaled', metavar='FLOAT', type=float, default=0,
help='downsample query to this scaled factor (yields greater speed)'
)
subparser.add_argument(
'-o', '--output', metavar='FILE',
help='output CSV containing matches to this file'
Expand All @@ -61,6 +57,7 @@ def subparser(subparsers):
add_ksize_arg(subparser, 31)
add_moltype_args(subparser)
add_picklist_args(subparser)
add_scaled_arg(subparser, 0)


def main(args):
Expand Down
24 changes: 24 additions & 0 deletions src/sourmash/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from glob import glob
import os
import argparse
from sourmash.logging import notify


def add_moltype_args(parser):
Expand Down Expand Up @@ -93,3 +94,26 @@ def command_list(dirpath):
basenames = [os.path.splitext(path)[0] for path in filenames if not path.startswith('__')]
basenames = filter(opfilter, basenames)
return sorted(basenames)


def check_scaled_bounds(arg):
actual_min_val = 0
min_val = 100
max_val = 1e6

f = float(arg)

if f < actual_min_val:
raise argparse.ArgumentTypeError(f"ERROR: --scaled value must be positive")
if f < min_val:
notify('WARNING: --scaled value should be >= 100. Continuing anyway.')
if f > max_val:
notify('WARNING: --scaled value should be <= 1e6. Continuing anyway.')
return f


def add_scaled_arg(parser, default=None):
parser.add_argument(
'--scaled', metavar='FLOAT', type=check_scaled_bounds,
help='scaled value should be between 100 and 1e6'
)
14 changes: 5 additions & 9 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def index(args):

if args.scaled:
ss.minhash = ss.minhash.downsample(scaled=args.scaled)

if ss.minhash.track_abundance:
ss.minhash = ss.minhash.flatten()
scaleds.add(ss.minhash.scaled)
Expand Down Expand Up @@ -450,15 +451,13 @@ def search(args):
query.minhash.ksize,
sourmash_args.get_moltype(query))

# downsample if requested
if args.scaled:
if not query.minhash.scaled:
error('cannot downsample a signature not created with --scaled')
sys.exit(-1)

if args.scaled != query.minhash.scaled:
notify('downsampling query from scaled={} to {}',
query.minhash.scaled, int(args.scaled))
query.minhash.scaled, int(args.scaled))
query.minhash = query.minhash.downsample(scaled=args.scaled)

# set up the search databases
Expand Down Expand Up @@ -646,10 +645,9 @@ def gather(args):
error('query signature needs to be created with --scaled')
sys.exit(-1)

# downsample if requested
if args.scaled:
notify('downsampling query from scaled={} to {}',
query.minhash.scaled, int(args.scaled))
query.minhash.scaled, int(args.scaled))
query.minhash = query.minhash.downsample(scaled=args.scaled)

# empty?
Expand Down Expand Up @@ -868,12 +866,11 @@ def multigather(args):
error('query signature needs to be created with --scaled; skipping')
continue

# downsample if requested
if args.scaled:
notify('downsampling query from scaled={} to {}',
query.minhash.scaled, int(args.scaled))
query.minhash.scaled, int(args.scaled))
query.minhash = query.minhash.downsample(scaled=args.scaled)

# empty?
if not len(query.minhash):
error('no query hashes!? skipping to next..')
Expand Down Expand Up @@ -1137,7 +1134,6 @@ def prefetch(args):
if query_mh.track_abundance:
query_mh = query_mh.flatten()

# downsample if/as requested
if args.scaled:
notify(f'downsampling query from scaled={query_mh.scaled} to {int(args.scaled)}')
query_mh = query_mh.downsample(scaled=args.scaled)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,48 @@ def test_prefetch_no_db(runtmp, linear_gather):
assert "ERROR: no databases or signatures to search!?" in c.last_result.err


def test_prefetch_check_scaled_bounds_negative(runtmp, linear_gather):
c = runtmp

sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

with pytest.raises(ValueError) as exc:
c.run_sourmash('prefetch', '-k', '31', sig47, sig63, sig2, sig47,
'--scaled', '-5', linear_gather)

assert "ERROR: --scaled value must be positive" in str(exc.value)


def test_prefetch_check_scaled_bounds_less_than_minimum(runtmp, linear_gather):
c = runtmp

sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

with pytest.raises(ValueError) as exc:
c.run_sourmash('prefetch', '-k', '31', sig47, sig63, sig2, sig47,
'--scaled', '50', linear_gather)

assert "WARNING: --scaled value should be >= 100. Continuing anyway." in str(exc.value)


def test_prefetch_check_scaled_bounds_more_than_maximum(runtmp, linear_gather):
c = runtmp

sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

with pytest.raises(ValueError) as exc:
c.run_sourmash('prefetch', '-k', '31', sig47, sig63, sig2, sig47,
'--scaled', '1e9', linear_gather)

assert "WARNING: --scaled value should be <= 1e6. Continuing anyway." in str(exc.value)


def test_prefetch_downsample_scaled(runtmp, linear_gather):
c = runtmp

Expand Down
Loading

0 comments on commit 96920c7

Please sign in to comment.