Feature: command line parse xml #197

Merged
merged 14 commits into from
Aug 24, 2018
21 changes: 21 additions & 0 deletions pyhf/commandline.py
@@ -0,0 +1,21 @@
import logging
logging.basicConfig()
log = logging.getLogger(__name__)

import click
import json
from . import readxml

@click.group(context_settings=dict(help_option_names=['-h', '--help']))
def pyhf():
    pass

@pyhf.command()
@click.option('--entrypoint-xml', required=True, prompt='Top-level XML', help='The top-level XML file for the PDF definition.', type=click.Path(exists=True))
@click.option('--basedir', required=True, prompt='Base directory', help='The base directory for the XML files to point relative to.', type=click.Path(exists=True))
@click.option('--output-file', required=True, prompt='Output file', help='The location of the output json file. If not specified, prints to screen.')
Contributor

What do you think about making entrypoint-xml a click.argument? There is really no way to convert without an input, so pyhf xml2json input.xml seems like a good command line.

--basedir could default to os.getcwd()

also, maybe it's somewhat more unixy to print to stdout if the output file is not provided?

pyhf xml2json input.xml > test.json ?
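A minimal sketch of what this suggestion could look like, assuming click's standard `argument` and `option` decorators. `parse_stub` is a hypothetical stand-in for `readxml.parse` so that the snippet runs without any XML or ROOT inputs; this is an illustration of the proposal, not the code that was merged.

```python
import json
import os

import click


def parse_stub(entrypoint_xml, basedir):
    """Hypothetical stand-in for readxml.parse; returns a tiny dict."""
    return {'entrypoint': os.path.basename(entrypoint_xml), 'basedir': basedir}


@click.command()
@click.argument('entrypoint_xml', type=click.Path(exists=True))
@click.option('--basedir', default=None, type=click.Path(exists=True),
              help='Base directory for the XML files (defaults to the current directory).')
@click.option('--output-file', default=None,
              help='Output JSON file. If not given, the JSON is printed to stdout.')
def xml2json(entrypoint_xml, basedir, output_file):
    # Resolve the default at call time rather than at import time
    basedir = basedir or os.getcwd()
    spec = parse_stub(entrypoint_xml, basedir)
    text = json.dumps(spec, indent=4, sort_keys=True)
    if output_file is None:
        # No output file: write to stdout so the shell can redirect it
        click.echo(text)
    else:
        with open(output_file, 'w') as handle:
            handle.write(text)
```

With this shape, `pyhf xml2json input.xml > test.json` and `pyhf xml2json input.xml --output-file test.json` would both be possible.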

Contributor Author

pyhf xml2json input.xml > test.json ?

The good part is that tqdm writes to stderr, so we can definitely do that.
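A standard-library sketch of why that separation makes the redirect safe: progress messages go to stderr, and only the JSON goes to stdout. `report_and_dump` and its arguments are hypothetical names used for illustration, and plain `write` calls stand in for the tqdm progress bar.

```python
import json
import sys


def report_and_dump(spec, steps, out=None, err=None):
    """Write progress to `err` and the JSON document to `out`."""
    out = sys.stdout if out is None else out
    err = sys.stderr if err is None else err
    for index, step in enumerate(steps, 1):
        # Progress reporting never touches the stream that carries the JSON
        err.write('processing {0} ({1}/{2})\n'.format(step, index, len(steps)))
    json.dump(spec, out, indent=4, sort_keys=True)
    out.write('\n')
```

Because nothing but the JSON reaches stdout, redirecting stdout to a file yields a document that can be loaded directly.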

@click.option('--tqdm/--no-tqdm', default=True)
Contributor

suggestion: --track/--no-track or --track-progress/--no-track-progress

def xml2json(entrypoint_xml, basedir, output_file, tqdm):
    spec = readxml.parse(entrypoint_xml, basedir, enable_tqdm=tqdm)
    json.dump(spec, open(output_file, 'w+'), indent=4, sort_keys=True)
    log.info("Written to {0:s}".format(output_file))
46 changes: 30 additions & 16 deletions pyhf/readxml.py
@@ -1,15 +1,17 @@
import logging
log = logging.getLogger(__name__)

import os
import xml.etree.ElementTree as ET
import numpy as np
import logging

log = logging.getLogger(__name__)
import tqdm

def import_root_histogram(rootdir, filename, path, name):
    import uproot
    #import pdb; pdb.set_trace()
    #assert path == ''
    # strip leading slashes as uproot doesn't use "/" for top-level
    if path is None: path = ''
Contributor Author

This was needed to handle situations where HistoPath wasn't included -- and in these cases, it's equivalent to ''. This code does need to be fixed up more to normalize the XMLs better...

Contributor

maybe

path = path or ''

is more pythonic?
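For the inputs that matter here (a string or `None`), the two spellings give the same result. A small sketch, with `normalize_histopath` as a hypothetical helper name that does not appear in the PR:

```python
def normalize_histopath(path):
    """Return a histogram path without leading or trailing slashes."""
    # `path or ''` maps None (and any other falsy value, such as '') to ''
    path = path or ''
    return path.strip('/')
```

The one difference from `if path is None` is that `or` replaces every falsy value, not just `None`; for string paths the only other falsy value is `''`, which maps to itself.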

    path = path.strip('/')
    f = uproot.open(os.path.join(rootdir, filename))
    try:
@@ -26,7 +28,7 @@ def import_root_histogram(rootdir, filename, path, name):

raise KeyError('Both {0:s} and {1:s} were tried and not found in {2:s}'.format(name, os.path.join(path, name), os.path.join(rootdir, filename)))

def process_sample(sample,rootdir,inputfile, histopath, channelname):
def process_sample(sample,rootdir,inputfile, histopath, channelname, enable_tqdm=False):
if 'InputFile' in sample.attrib:
Contributor

here and in the other cases, I'd also suggest renaming to track_progress instead of enable_tqdm

inputfile = sample.attrib.get('InputFile')
if 'HistoPath' in sample.attrib:
@@ -36,7 +38,11 @@ def process_sample(sample,rootdir,inputfile, histopath, channelname):
data,err = import_root_histogram(rootdir, inputfile, histopath, histoname)

modifiers = []
for modtag in sample.iter():

modtags = tqdm.tqdm(sample.iter(), unit='modifier', disable=not(enable_tqdm), total=len(sample))

for modtag in modtags:
modtags.set_description(' - modifier {0:s}({1:s})'.format(modtag.attrib.get('Name', 'n/a'), modtag.tag))
if modtag == sample:
continue
if modtag.tag == 'OverallSys':
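A note on the hunk above: the progress bar is given `total=len(sample)` while the loop iterates over `sample.iter()`. A standard-library sketch of how those two counts relate, using a made-up XML snippet (the element names mirror the tags in this diff, the attribute values are invented):

```python
import xml.etree.ElementTree as ET

SAMPLE_XML = """
<Sample Name="signal">
  <OverallSys Name="syst1" High="1.05" Low="0.95"/>
  <NormFactor Name="mu" Val="1" Low="0" High="3"/>
</Sample>
""".strip()

sample = ET.fromstring(SAMPLE_XML)

# len() counts direct children only
n_children = len(sample)

# iter() yields the element itself first, then all of its descendants
n_iterated = len(list(sample.iter()))

# Skipping the element itself leaves just the modifier tags
modifiers = [element for element in sample.iter() if element is not sample]
```

For a flat sample like this one the iteration count is one higher than `len(sample)`, since the loop also visits the `Sample` element before skipping it with `continue`.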
@@ -51,7 +57,6 @@ def process_sample(sample,rootdir,inputfile, histopath, channelname):
'type': 'normfactor',
'data': None
})

elif modtag.tag == 'HistoSys':
lo,_ = import_root_histogram(rootdir,
modtag.attrib.get('HistoFileLow',inputfile),
@@ -97,26 +102,35 @@ def process_data(sample,rootdir,inputfile, histopath):
data,_ = import_root_histogram(rootdir, inputfile, histopath, histoname)
return data

def process_channel(channelxml,rootdir):
def process_channel(channelxml, rootdir, enable_tqdm=False):
channel = channelxml.getroot()

inputfile = channel.attrib.get('InputFile')
histopath = channel.attrib.get('HistoPath')

samples = channel.findall('Sample')

samples = tqdm.tqdm(channel.findall('Sample'), unit='sample', disable=not(enable_tqdm))

data = channel.findall('Data')[0]

channelname = channel.attrib['Name']
return channelname, process_data(data, rootdir, inputfile, histopath), [process_sample(x, rootdir, inputfile, histopath, channelname) for x in samples]

def parse(configfile,rootdir):
results = []
for sample in samples:
samples.set_description(' - sample {}'.format(sample.attrib.get('Name')))
result = process_sample(sample, rootdir, inputfile, histopath, channelname, enable_tqdm)
results.append(result)

return channelname, process_data(data, rootdir, inputfile, histopath), results

def parse(configfile, rootdir, enable_tqdm=False):
toplvl = ET.parse(configfile)
inputs = [ET.parse(os.path.join(rootdir,x.text)) for x in toplvl.findall('Input')]
channels = {
k:{'data': d, 'samples': v} for k,d,v in [process_channel(inp,rootdir) for inp in inputs]
}
inputs = tqdm.tqdm([x.text for x in toplvl.findall('Input')], unit='channel', disable=not(enable_tqdm))

channels = {}
for inp in inputs:
inputs.set_description('Processing {}'.format(inp))
channel, data, samples = process_channel(ET.parse(os.path.join(rootdir,inp)), rootdir, enable_tqdm)
channels[channel] = {'data': data, 'samples': samples}

return {
'toplvl':{
'resultprefix':toplvl.getroot().attrib['OutputFilePrefix'],
6 changes: 5 additions & 1 deletion setup.py
@@ -10,7 +10,9 @@
include_package_data = True,
install_requires = [
'numpy<=1.14.5,>=1.14.3', # required by tensorflow, mxnet, and us
'scipy'
'scipy',
'click>=6.0', # for console scripts,
'tqdm', # for readxml
],
extras_require = {
'xmlimport': [
@@ -35,6 +37,7 @@
'pytest>=3.5.1',
'pytest-cov>=2.5.1',
'pytest-benchmark[histogram]',
'pytest-console-scripts',
'python-coveralls',
'coverage==4.0.3', # coveralls
'matplotlib',
@@ -52,6 +55,7 @@
]
},
entry_points = {
'console_scripts': ['pyhf=pyhf.commandline:pyhf']
},
dependency_links = [
]
27 changes: 27 additions & 0 deletions tests/test_scripts.py
@@ -0,0 +1,27 @@
import pytest
import json
import shlex

import pyhf

# see test_import.py for the same (detailed) test
def test_import_prepHistFactory(tmpdir, script_runner):
Contributor

where is this fixture defined?

Contributor Author

Comes from pytest-console-scripts (#198) which adds the fixture (see readme)

Contributor

Ah, OK. click has some built-in testing capabilities in click.testing. We could use that and avoid the dependency, unless pytest-console-scripts adds some nice features (I haven't used it).

example usage
https://github.com/yadage/yadage/blob/master/tests/test_maincli.py#L5

Contributor Author

CliRunner doesn't isolate stdout/stderr. It's probably only specific to running click-enabled commands. The pytest-console-scripts is much more generic (runs any script). I would use CliRunner if I spent more time figuring out stderr extraction.

Contributor

OK, yes. Testing stdout and stderr separately is important, especially if we want to do e.g. > bla.json, where we must ensure that stdout is JSON-deserializable. Let's go with pytest-console-scripts then.
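The property being asked for can be sketched with the standard library alone: run a script in a subprocess, capture the two streams separately, and check that stdout parses as JSON. This uses `subprocess` as a stand-in rather than the `script_runner` fixture, and `SCRIPT` and `run_script` are hypothetical names for illustration.

```python
import json
import subprocess
import sys

# A tiny script that reports progress on stderr and emits JSON on stdout
SCRIPT = (
    "import json, sys\n"
    "sys.stderr.write('Processing example.xml\\n')\n"
    "json.dump({'channels': []}, sys.stdout)\n"
)


def run_script(source):
    """Run Python source in a subprocess; return (returncode, stdout, stderr)."""
    proc = subprocess.run(
        [sys.executable, '-c', source],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    return proc.returncode, proc.stdout, proc.stderr
```

A test can then assert on each stream independently, which is the kind of check the tests below make through `ret.stdout` and `ret.stderr`.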

    temp = tmpdir.join("parsed_output.json")
    command = 'pyhf xml2json --entrypoint-xml validation/xmlimport_input/config/example.xml --basedir validation/xmlimport_input/ --output-file {0:s} --no-tqdm'.format(temp.strpath)
    ret = script_runner.run(*shlex.split(command))
    assert ret.success
    assert ret.stdout == ''
    assert ret.stderr == ''

    parsed_xml = json.loads(temp.read())
    spec = {'channels': parsed_xml['channels']}
    pyhf.utils.validate(spec, pyhf.utils.get_default_schema())

def test_import_prepHistFactory_TQDM(tmpdir, script_runner):
    temp = tmpdir.join("parsed_output.json")
    command = 'pyhf xml2json --entrypoint-xml validation/xmlimport_input/config/example.xml --basedir validation/xmlimport_input/ --output-file {0:s}'.format(temp.strpath)
    ret = script_runner.run(*shlex.split(command))
    assert ret.success
    assert ret.stdout == ''
    assert ret.stderr != ''