Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
moustakas committed Feb 26, 2017
1 parent bbc3f6a commit 339bbf5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 26 deletions.
41 changes: 26 additions & 15 deletions py/desisim/lya_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,33 @@

from __future__ import division, print_function

def get_spectra(lyafile, templateid=None, wave=None, normfilter='sdss2010-g',
rand=None, qso=None):
def get_spectra(lyafile, nqso=None, wave=None, templateid=None, normfilter='sdss2010-g',
seed=None, rand=None, qso=None):
'''Generate a QSO spectrum which includes Lyman-alpha absorption.
Args:
lyafile (str): name of the Lyman-alpha spectrum file to read.
templateid (int numpy.ndarray, optional): indices of the spectra
(0-indexed) to read from LYAFILE (default is to read everything).
nqso (int, optional): number of spectra to generate (starting from the
first spectrum; if more flexibility is needed use TEMPLATEID).
wave (numpy.ndarray, optional): desired output wavelength vector.
templateid (int numpy.ndarray, optional): indices of the spectra
(0-indexed) to read from LYAFILE (default is to read everything). If
provided together with NQSO, TEMPLATEID wins.
normfilter (str, optional): normalization filter
seed (int, optional): Seed for random number generator.
rand (numpy.RandomState, optional): RandomState object used for the
random number generation.
random number generation. If provided together with SEED, this
optional input superseeds the numpy.RandomState object instantiated by
SEED.
qso (desisim.templates.QSO, optional): object with which to generate
individual spectra/templates.
Returns:
flux (numpy.ndarray): Array [nmodel, npix] of observed-frame spectra
(erg/s/cm2/A).
wave (numpy.ndarray): Observed-frame [npix] wavelength array (Angstrom).
meta (astropy.Table): Table of meta-data [nmodel] for each output spectrum
with columns defined in desisim.io.empty_metatable *plus* RA, DEC.
flux (numpy.ndarray): Array [nmodel, npix] of observed-frame spectra
(erg/s/cm2/A).
wave (numpy.ndarray): Observed-frame [npix] wavelength array (Angstrom).
meta (astropy.Table): Table of meta-data [nmodel] for each output spectrum
with columns defined in desisim.io.empty_metatable *plus* RA, DEC.
'''
import numpy as np
Expand All @@ -41,17 +47,22 @@ def get_spectra(lyafile, templateid=None, wave=None, normfilter='sdss2010-g',

h = fitsio.FITS(lyafile)
if templateid is None:
nqso = len(h)-1
if nqso is None:
nqso = len(h)-1
templateid = np.arange(nqso)
else:
templateid = np.array(templateid)
nqso = len(templateid)
print(templateid)

if rand is None:
rand = np.random.RandomState()
seed = rand.randint(2**32, size=nqso)
rand = np.random.RandomState(seed)
templateseed = rand.randint(2**32, size=nqso)

#heads = [head.read_header() for head in h[templateid + 1]]
heads = []
for indx in templateid:
print(indx + 1)
heads.append(h[indx + 1].read_header())

zqso = np.array([head['ZQSO'] for head in heads])
Expand All @@ -72,13 +83,13 @@ def get_spectra(lyafile, templateid=None, wave=None, normfilter='sdss2010-g',
meta['TEMPLATEID'] = templateid
meta['REDSHIFT'] = zqso
meta['MAG'] = mag_g
meta['SEED'] = seed
meta['SEED'] = templateseed
meta['RA'] = ra
meta['DEC'] = dec

for ii, indx in enumerate(templateid):
flux1, _, meta1 = qso.make_templates(nmodel=1, redshift=np.array([zqso[ii]]),
mag=np.array([mag_g[ii]]), seed=seed[ii])
mag=np.array([mag_g[ii]]), seed=templateseed[ii])

# read lambda and forest transmission
data = h[indx + 1].read()
Expand Down
30 changes: 19 additions & 11 deletions py/desisim/test/test_lya.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,38 @@ def setUpClass(cls):
fx = fitsio.FITS(cls.infile)
cls.nspec = len(fx) - 1
fx.close()

cls.wavemin = 5000
cls.wavemax = 8000
cls.dwave = 2.0
cls.wave = np.arange(cls.wavemin, cls.wavemax+cls.dwave/2, cls.dwave)
cls.nspec = 5
cls.templateid = [3, 10, 500]
cls.seed = np.random.randint(2**32)
cls.rand = np.random.RandomState(cls.seed)

@unittest.skipIf(missing_fitsio, 'fitsio not installed; skipping lya_spectra tests')
def test_read_lya(self):
flux, wave, meta = lya_spectra.get_spectra(self.infile)
flux, wave, meta = lya_spectra.get_spectra(self.infile, wave=self.wave, seed=self.seed)
self.assertEqual(flux.shape[0], self.nspec)
self.assertEqual(wave.shape[0], flux.shape[1])
self.assertEqual(len(meta), self.nspec)

nqso = 3
flux, wave, meta = lya_spectra.get_spectra(self.infile, nqso=nqso)
flux, wave, meta = lya_spectra.get_spectra(self.infile, templateid=templateid,
wave=self.wave, seed=self.seed)
self.assertEqual(flux.shape[0], nqso)
self.assertEqual(wave.shape[0], flux.shape[1])
self.assertEqual(len(meta), nqso)

flux, wave, meta = lya_spectra.get_spectra(self.infile, nqso=nqso, first=2)
self.assertEqual(flux.shape[0], nqso)
self.assertEqual(wave.shape[0], flux.shape[1])
self.assertEqual(len(meta), nqso)
#flux, wave, meta = lya_spectra.get_spectra(self.infile, nqso=nqso, first=2)
#self.assertEqual(flux.shape[0], nqso)
#self.assertEqual(wave.shape[0], flux.shape[1])
#self.assertEqual(len(meta), nqso)

@unittest.skipIf(missing_fitsio, 'fitsio not installed; skipping lya_spectra tests')
def test_read_lya_seed(self):
flux1a, wave1a, meta1a = lya_spectra.get_spectra(self.infile, nqso=3, seed=1)
flux1b, wave1b, meta1b = lya_spectra.get_spectra(self.infile, nqso=3, seed=1)
flux2, wave2, meta2 = lya_spectra.get_spectra(self.infile, nqso=3, seed=2)
flux1a, wave1a, meta1a = lya_spectra.get_spectra(self.infile, wave=self.wave, nqso=3, seed=1)
flux1b, wave1b, meta1b = lya_spectra.get_spectra(self.infile, wave=self.wave, nqso=3, seed=1)
flux2, wave2, meta2 = lya_spectra.get_spectra(self.infile, wave=self.wave, nqso=3, seed=2)
self.assertTrue(np.all(flux1a == flux1b))
self.assertTrue(np.any(flux1a != flux2))

Expand Down

0 comments on commit 339bbf5

Please sign in to comment.