Skip to content

Commit

Permalink
Merge pull request #343 from wkerzendorf/plasma/multi_output
Browse files Browse the repository at this point in the history
Plasma/multi output
  • Loading branch information
aoifeboyle committed Jul 1, 2015
2 parents 3fc862f + f69f230 commit 50bcada
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 137 deletions.
2 changes: 1 addition & 1 deletion astropy_helpers
Submodule astropy_helpers updated 51 files
+0 −1 .gitignore
+21 −35 .travis.yml
+3 −165 CHANGES.rst
+1 −7 README.rst
+46 −159 ah_bootstrap.py
+0 −52 appveyor.yml
+0 −29 astropy_helpers/__init__.py
+0 −183 astropy_helpers/commands/build_ext.py
+0 −39 astropy_helpers/commands/build_py.py
+0 −224 astropy_helpers/commands/build_sphinx.py
+0 −14 astropy_helpers/commands/install.py
+0 −14 astropy_helpers/commands/install_lib.py
+0 −53 astropy_helpers/commands/register.py
+0 −4 astropy_helpers/commands/setup_package.py
+0 −257 astropy_helpers/distutils_helpers.py
+9 −42 astropy_helpers/git_helpers.py
+823 −169 astropy_helpers/setup_helpers.py
+1 −5 astropy_helpers/sphinx/conf.py
+2 −17 astropy_helpers/sphinx/ext/astropyautosummary.py
+0 −98 astropy_helpers/sphinx/ext/autodoc_enhancements.py
+7 −21 astropy_helpers/sphinx/ext/automodapi.py
+15 −40 astropy_helpers/sphinx/ext/automodsumm.py
+0 −56 astropy_helpers/sphinx/ext/tests/test_autodoc_enhancements.py
+11 −8 astropy_helpers/sphinx/ext/tests/test_automodapi.py
+7 −6 astropy_helpers/sphinx/ext/tests/test_automodsumm.py
+7 −10 astropy_helpers/sphinx/ext/viewcode.py
+ astropy_helpers/sphinx/local/python3links.inv
+0 −7 astropy_helpers/sphinx/local/python3links.txt
+1 −1 astropy_helpers/sphinx/themes/bootstrap-astropy/layout.html
+0 −75 astropy_helpers/sphinx/themes/bootstrap-astropy/static/astropy_linkout.svg
+ astropy_helpers/sphinx/themes/bootstrap-astropy/static/astropy_logo.ico
+0 −87 astropy_helpers/sphinx/themes/bootstrap-astropy/static/astropy_logo.svg
+3 −15 astropy_helpers/sphinx/themes/bootstrap-astropy/static/bootstrap-astropy.css
+0 −0 astropy_helpers/src/__init__.py
+0 −0 astropy_helpers/src/compiler.c
+2 −0 astropy_helpers/src/setup_package.py
+3 −9 astropy_helpers/test_helpers.py
+0 −36 astropy_helpers/tests/__init__.py
+5 −42 astropy_helpers/tests/test_ah_bootstrap.py
+26 −85 astropy_helpers/tests/test_git_helpers.py
+46 −244 astropy_helpers/tests/test_setup_helpers.py
+5 −504 astropy_helpers/utils.py
+56 −150 astropy_helpers/version_helpers.py
+0 −71 continuous-integration/appveyor/install-miniconda.ps1
+0 −47 continuous-integration/appveyor/windows_sdk.cmd
+0 −7 continuous-integration/travis/install_conda_linux.sh
+0 −7 continuous-integration/travis/install_conda_osx.sh
+0 −4 continuous-integration/travis/install_graphviz_linux.sh
+0 −4 continuous-integration/travis/install_graphviz_osx.sh
+3 −6 setup.py
+2 −7 tox.ini
108 changes: 64 additions & 44 deletions tardis/plasma/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@


class BasePlasma(object):

def __init__(self, plasma_modules, **kwargs):
self.module_dict = {}
self.input_modules = []
self._init_modules(plasma_modules, **kwargs)
outputs_dict = {}
def __init__(self, plasma_properties, **kwargs):
self.outputs_dict = {}
self.input_properties = []
self.plasma_properties = self._init_properties(plasma_properties,
**kwargs)

self._build_graph()
self.update(**kwargs)


def __getattr__(self, item):
if item in self.module_dict:
if item in self.outputs_dict:
return self.get_value(item)
else:
super(BasePlasma, self).__getattribute__(item)


def __setattr__(self, key, value):
if key != 'module_dict' and key in self.module_dict:
if key != 'module_dict' and key in self.outputs_dict:
raise AttributeError('Plasma inputs can only be updated using '
'the \'update\' method')
else:
Expand All @@ -42,8 +42,12 @@ def __dir__(self):

return attrs

@property
def plasma_properties_dict(self):
return {item.name:item for item in self.plasma_properties}

def get_value(self, item):
return self.module_dict[item].value
return getattr(self.outputs_dict[item], item)

def _build_graph(self):
"""
Expand All @@ -56,62 +60,80 @@ def _build_graph(self):
self.graph = nx.DiGraph()

## Adding all nodes
self.graph.add_nodes_from([(key, {})
for key, value in self.module_dict.items()])
self.graph.add_nodes_from([(plasma_property.name, {})
for plasma_property
in self.plasma_properties])

#Flagging all input modules
self.input_modules = [key for key, item in self.module_dict.items()
if not hasattr(item, 'inputs')]
self.input_properties = [item for item in self.plasma_properties
if not hasattr(item, 'inputs')]

for plasma_module in self.module_dict.values():
for plasma_property in self.plasma_properties:
#Skipping any module that is an input module
if plasma_module.name in self.input_modules:
if plasma_property in self.input_properties:
continue

for input in plasma_module.inputs:
if input not in self.graph:
for input in plasma_property.inputs:
if input not in self.outputs_dict:
raise PlasmaMissingModule('Module {0} requires input '
'{1} which has not been added'
' to this plasma'.format(
plasma_module.name, input))
self.graph.add_edge(input, plasma_module.name)
plasma_property.name, input))
self.graph.add_edge(self.outputs_dict[input].name,
plasma_property.name, label=input)

def _init_modules(self, plasma_modules, **kwargs):
def _init_properties(self, plasma_properties, **kwargs):
"""
Builds a dictionary with the plasma module names as keys
:param plasma_modules:
:return:
"""
self.module_dict = {}
for module in plasma_modules:
if hasattr(module, 'set_value'):
if module.name not in kwargs:
Parameters
----------
plasma_modules: ~list
list of Plasma properties
kwargs: dictionary
input values for input properties. For example, t_rad=[5000, 6000,],
j_blues=[..]
"""
plasma_property_objects = []
self.outputs_dict = {}
for plasma_property in plasma_properties:

if hasattr(plasma_property, 'set_value'):
#duck-typing for PlasmaInputProperty
#that means if it is an input property from model
if not set(kwargs.keys()).issuperset(plasma_property.outputs):
missing_input_values = (set(plasma_property.outputs) -
set(kwargs.keys()))
raise NotInitializedModule('Input {0} required for '
'plasma but not given when '
'instantiating the '
'plasma'.format(module.name))
current_module_object = module()
'plasma'.format(
missing_input_values))
current_property_object = plasma_property()
else:
current_module_object = module(self)

self.module_dict[module.name] = current_module_object
current_property_object = plasma_property(self)
for output in plasma_property.outputs:
self.outputs_dict[output] = current_property_object
plasma_property_objects.append(current_property_object)
return plasma_property_objects

def update(self, **kwargs):
for key in kwargs:
if key not in self.module_dict:
if key not in self.outputs_dict:
raise PlasmaMissingModule('Trying to update property {0}'
' that is unavailable'.format(key))
self.module_dict[key].set_value(kwargs[key])
self.outputs_dict[key].set_value(kwargs[key])

for module_name in self._resolve_update_list(kwargs.keys()):
self.module_dict[module_name].update()
self.plasma_properties_dict[module_name].update()

def _update_module_type_str(self):
for node in self.graph:
self.module_dict[node]._update_type_str()
self.outputs_dict[node]._update_type_str()

def _resolve_update_list(self, changed_modules):
def _resolve_update_list(self, changed_properties):
"""
Returns a list of all plasma models which are affected by the
changed_modules due to there dependency in the
Expand All @@ -120,8 +142,6 @@ def _resolve_update_list(self, changed_modules):
Parameters
----------
graph: ~networkx.Graph
the plasma graph as
changed_modules: ~list
all modules changed in the plasma
Expand All @@ -134,8 +154,9 @@ def _resolve_update_list(self, changed_modules):

descendants_ob = []

for module in changed_modules:
descendants_ob += nx.descendants(self.graph, module)
for plasma_property in changed_properties:
node_name = self.outputs_dict[plasma_property].name
descendants_ob += nx.descendants(self.graph, node_name)

descendants_ob = list(set(descendants_ob))
sort_order = nx.topological_sort(self.graph)
Expand All @@ -147,7 +168,7 @@ def _resolve_update_list(self, changed_modules):

return descendants_ob

def write_to_dot(self, fname):
def write_to_dot(self, fname, latex_label=True):
self._update_module_type_str()

try:
Expand Down Expand Up @@ -187,4 +208,3 @@ def __init__(self, number_densities, atom_data, time_explosion,
link_t_rad_t_electron=0.9):

pass

82 changes: 43 additions & 39 deletions tardis/plasma/properties/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ class BaseAtomicDataProperty(ProcessingPlasmaProperty):
inputs = ['atomic_data', 'selected_atoms']

def __init__(self, plasma_parent):

super(BaseAtomicDataProperty, self).__init__(plasma_parent)
self.value = None
assert len(self.outputs) == 1

@abstractmethod
def _set_index(self, raw_atomic_property, atomic_data):
Expand All @@ -31,31 +33,33 @@ def _filter_atomic_property(self, raw_atomic_property):


def calculate(self, atomic_data, selected_atoms):
if self.value is not None:
return self.value

if getattr(self, self.outputs[0]) is not None:
return getattr(self, self.outputs[0])
else:
try:
raw_atomic_property = getattr(atomic_data, '_' + self.name)
return self._set_index(self._filter_atomic_property(
raw_atomic_property, selected_atoms), atomic_data)
except:
raw_atomic_property = getattr(atomic_data, self.name)
raw_atomic_property = getattr(atomic_data, '_' + self.outputs[0])
except AttributeError:
raw_atomic_property = getattr(atomic_data, self.outputs[0])
finally:
return self._set_index(self._filter_atomic_property(
raw_atomic_property, selected_atoms), atomic_data)


class Levels(BaseAtomicDataProperty):
name = 'levels'
outputs = ('levels',)

def _filter_atomic_property(self, levels, selected_atoms):
return levels[levels.atomic_number.isin(selected_atoms)]
return levels[levels.atomic_number.isin([selected_atoms]
if np.isscalar(selected_atoms)
else selected_atoms)]

def _set_index(self, levels, atomic_data):
return levels.set_index(['atomic_number', 'ion_number',
'level_number'])

class Lines(BaseAtomicDataProperty):
name = 'lines'
outputs = ('lines',)

def _filter_atomic_property(self, lines, selected_atoms):
return lines[lines.atomic_number.isin(selected_atoms)]
Expand All @@ -68,7 +72,7 @@ def _set_index(self, lines, atomic_data):
return reindexed

class LinesLowerLevelIndex(ProcessingPlasmaProperty):
name = 'lines_lower_level_index'
outputs = ('lines_lower_level_index',)

def calculate(self, levels, lines):
levels_index = pd.Series(np.arange(len(levels), dtype=np.int64),
Expand All @@ -79,7 +83,7 @@ def calculate(self, levels, lines):
return np.array(levels_index.ix[lines_index])

class LinesUpperLevelIndex(ProcessingPlasmaProperty):
name = 'lines_upper_level_index'
outputs = ('lines_upper_level_index',)

def calculate(self, levels, lines):
levels_index = pd.Series(np.arange(len(levels), dtype=np.int64),
Expand All @@ -91,7 +95,7 @@ def calculate(self, levels, lines):


class IonCXData(BaseAtomicDataProperty):
name = 'ion_cx_data'
outputs = ('ion_cx_data',)

def _filter_atomic_property(self, ion_cx_data, selected_atoms):
return filtered_ion_cx_data
Expand All @@ -102,16 +106,16 @@ def _set_index(self, ion_cx_data, atomic_data):


class AtomicMass(ProcessingPlasmaProperty):
name = 'atomic_mass'
outputs = ('atomic_mass',)

def calculate(self, atomic_data, selected_atoms):
if self.value is not None:
return self.value
if getattr(self, self.outputs[0]) is not None:
return (getattr(self, self.outputs[0]),)
else:
return atomic_data.atom_data.ix[selected_atoms].mass

class IonizationData(BaseAtomicDataProperty):
name = 'ionization_data'
outputs = ('ionization_data',)

def _filter_atomic_property(self, ionization_data, selected_atoms):
ionization_data['atomic_number'] = ionization_data.index.labels[0]+1
Expand All @@ -132,7 +136,7 @@ def _set_index(self, ionization_data, atomic_data):
return ionization_data.set_index(['atomic_number', 'ion_number'])

class ZetaData(BaseAtomicDataProperty):
name = 'zeta_data'
outputs = ('zeta_data',)

def _filter_atomic_property(self, zeta_data, selected_atoms):
zeta_data['atomic_number'] = zeta_data.index.labels[0]+1
Expand All @@ -144,27 +148,27 @@ def _filter_atomic_property(self, zeta_data, selected_atoms):
if np.alltrue(keys+1==values):
return zeta_data
else:
raise IncompleteAtomicData('zeta data')
# logger.warn('Zeta_data missing - replaced with 1s')
# updated_index = []
# for atom in selected_atoms:
# for ion in range(1, atom+2):
# updated_index.append([atom,ion])
# updated_index = np.array(updated_index)
# updated_dataframe = pd.DataFrame(index=pd.MultiIndex.from_arrays(
# updated_index.transpose().astype(int)),
# columns = zeta_data.columns)
# for value in range(len(zeta_data)):
# updated_dataframe.ix[zeta_data.atomic_number.values[value]].ix[
# zeta_data.ion_number.values[value]] = \
# zeta_data.ix[zeta_data.atomic_number.values[value]].ix[
# zeta_data.ion_number.values[value]]
# updated_dataframe = updated_dataframe.astype(float)
# updated_index = pd.DataFrame(updated_index)
# updated_dataframe['atomic_number'] = np.array(updated_index[0])
# updated_dataframe['ion_number'] = np.array(updated_index[1])
# updated_dataframe.fillna(1.0, inplace=True)
# return updated_dataframe
# raise IncompleteAtomicData('zeta data')
logger.warn('Zeta_data missing - replaced with 1s')
updated_index = []
for atom in selected_atoms:
for ion in range(1, atom+2):
updated_index.append([atom,ion])
updated_index = np.array(updated_index)
updated_dataframe = pd.DataFrame(index=pd.MultiIndex.from_arrays(
updated_index.transpose().astype(int)),
columns = zeta_data.columns)
for value in range(len(zeta_data)):
updated_dataframe.ix[zeta_data.atomic_number.values[value]].ix[
zeta_data.ion_number.values[value]] = \
zeta_data.ix[zeta_data.atomic_number.values[value]].ix[
zeta_data.ion_number.values[value]]
updated_dataframe = updated_dataframe.astype(float)
updated_index = pd.DataFrame(updated_index)
updated_dataframe['atomic_number'] = np.array(updated_index[0])
updated_dataframe['ion_number'] = np.array(updated_index[1])
updated_dataframe.fillna(1.0, inplace=True)
return updated_dataframe

def _set_index(self, zeta_data, atomic_data):
return zeta_data.set_index(['atomic_number', 'ion_number'])
Expand Down
Loading

0 comments on commit 50bcada

Please sign in to comment.