Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve usability of Enum values #840

Merged
merged 6 commits into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

# 31.0.1 [#840](https://github.com/openfisca/openfisca-core/pull/840)

- Improve usability of Enum values:
- Details:
- Allow the use of Enum values in comparisons: instead of using `<Enum class>.possible_values` you can simply `import` the Enum class
- Accept Enum values via set_input (same result as the previous point)

# 31.0.0 [#813](https://github.com/openfisca/openfisca-core/pull/813)

#### Breaking changes
Expand Down
11 changes: 10 additions & 1 deletion openfisca_core/indexed_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def encode(cls, array):
if array.dtype.kind in {'U', 'S'}: # String array
array = np.select([array == item.name for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
elif array.dtype.kind == 'O': # Enum items arrays
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from variable.possible_values,
# while the array values may come from directly importing a module containing an Enum class.
# However, variables (and hence their possible_values) are loaded by a call to load_module,
# which gives them a different identity from the ones imported in the usual way.
# So, instead of relying on the "cls" passed in, we use only its name to check that
# the values in the array, if non-empty, are of the right type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shorten comment or keep in PR only?

Suggested change
# the values in the array, if non-empty, are of the right type.
# "cls" generally comes from variables.possible_values (thus, from load_module)
# while the array values may come from directly importing a module containing an Enum class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you know I'm not a fan of long comments in code, even of comments in general. In this particular case I really feel that the somewhat long explanation is needed to make sense of the code that follows…

if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
array = np.select([array == item for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
return EnumArray(array, cls)

Expand All @@ -75,7 +84,7 @@ def __array_finalize__(self, obj):

def __eq__(self, other):
# When comparing to an item of self.possible_values, use the item index to speed up the comparison
if other.__class__ is self.possible_values:
if other.__class__.__name__ is self.possible_values.__name__:
return self.view(np.ndarray) == other.index # use view(np.ndarray) so that the result is a classic ndarray, not an EnumArray
return self.view(np.ndarray) == other

Expand Down
4 changes: 3 additions & 1 deletion openfisca_core/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def calculate(self, variable_name, period, **parameters):
if array is None:
array = holder.default_array()

array = self._cast_formula_result(array, variable)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


holder.put_in_cache(array, period)
except SpiralError:
array = holder.default_array()
Expand Down Expand Up @@ -248,7 +250,7 @@ def _run_formula(self, variable, entity, period):
array = formula(entity, period, parameters_at)

self._check_formula_result(array, variable, entity, period)
return self._cast_formula_result(array, variable)
return array

def _check_period_consistency(self, period, variable):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

setup(
name = 'OpenFisca-Core',
version = '31.0.0',
version = '31.0.1',
author = 'OpenFisca Team',
author_email = 'contact@openfisca.org',
classifiers = [
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import openfisca_country_template.situation_examples
from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_country_template.variables.housing import HousingOccupancyStatus
from openfisca_core.periods import period as make_period, ETERNITY
from openfisca_core.tools import assert_near
from openfisca_core.memory_config import MemoryConfig
Expand All @@ -26,7 +27,6 @@ def couple():


period = make_period('2017-12')
HousingOccupancyStatus = tax_benefit_system.get_variable('housing_occupancy_status').possible_values


def test_set_input_enum_string(couple):
Expand Down