Skip to content

Commit

Permalink
Merge pull request #840 from openfisca/fix-double-import-errors
Browse files Browse the repository at this point in the history
Improve usability of Enum values
  • Loading branch information
Morendil authored Apr 5, 2019
2 parents 607a3b6 + eed5f30 commit ae1dad2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

# 31.0.1 [#840](https://github.com/openfisca/openfisca-core/pull/840)

- Improve usability of Enum values:
- Details:
- Allow the use of Enum values in comparisons: instead of using `<Enum class>.possible_values` you can simply `import` the Enum class
- Accept Enum values via set_input (same result as the previous point)

# 31.0.0 [#813](https://github.com/openfisca/openfisca-core/pull/813)

#### Breaking changes
Expand Down
11 changes: 10 additions & 1 deletion openfisca_core/indexed_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def encode(cls, array):
if array.dtype.kind in {'U', 'S'}: # String array
array = np.select([array == item.name for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
elif array.dtype.kind == 'O': # Enum items arrays
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from variable.possible_values,
# while the array values may come from directly importing a module containing an Enum class.
# However, variables (and hence their possible_values) are loaded by a call to load_module,
# which gives them a different identity from the ones imported in the usual way.
# So, instead of relying on the "cls" passed in, we use only its name to check that
# the values in the array, if non-empty, are of the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
array = np.select([array == item for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
return EnumArray(array, cls)

Expand All @@ -75,7 +84,7 @@ def __array_finalize__(self, obj):

def __eq__(self, other):
# When comparing to an item of self.possible_values, use the item index to speed up the comparison
if other.__class__ is self.possible_values:
if other.__class__.__name__ is self.possible_values.__name__:
return self.view(np.ndarray) == other.index # use view(np.ndarray) so that the result is a classic ndarray, not an EnumArray
return self.view(np.ndarray) == other

Expand Down
4 changes: 3 additions & 1 deletion openfisca_core/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def calculate(self, variable_name, period, **parameters):
if array is None:
array = holder.default_array()

array = self._cast_formula_result(array, variable)

holder.put_in_cache(array, period)
except SpiralError:
array = holder.default_array()
Expand Down Expand Up @@ -248,7 +250,7 @@ def _run_formula(self, variable, entity, period):
array = formula(entity, period, parameters_at)

self._check_formula_result(array, variable, entity, period)
return self._cast_formula_result(array, variable)
return array

def _check_period_consistency(self, period, variable):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

setup(
name = 'OpenFisca-Core',
version = '31.0.0',
version = '31.0.1',
author = 'OpenFisca Team',
author_email = 'contact@openfisca.org',
classifiers = [
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import openfisca_country_template.situation_examples
from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_country_template.variables.housing import HousingOccupancyStatus
from openfisca_core.periods import period as make_period, ETERNITY
from openfisca_core.tools import assert_near
from openfisca_core.memory_config import MemoryConfig
Expand All @@ -26,7 +27,6 @@ def couple():


period = make_period('2017-12')
HousingOccupancyStatus = tax_benefit_system.get_variable('housing_occupancy_status').possible_values


def test_set_input_enum_string(couple):
Expand Down

0 comments on commit ae1dad2

Please sign in to comment.